儲存網絡結構以及參數
- 一、儲存方式
- 二、pkl、pth檔案差別
-
-
- 2.1 .pkl檔案
- 2.2 .pth檔案
-
對于pytorch儲存網絡參數,大家一般可以看到有 .pkl檔案 以及 .pth檔案,對于這兩者有什麼差別,以及如何儲存網絡參數等,本文就好好講述一下。
一、儲存方式
首先我們知道不論是儲存模型還是參數都需要用到
torch.save()
。
對于
torch.save()
有兩種儲存方式:
- 隻儲存神經網絡的訓練模型的參數,save的對象是
;model.state_dict()
- 既儲存整個神經網絡的的模型結構又儲存模型參數,那麼save的對象就是整個模型;
Eg. 假設我有一個訓練好的模型名叫model,如何來儲存參數以及結構?
import torch
# 儲存模型步驟
torch.save(model, 'net.pth') # 儲存整個神經網絡的模型結構以及參數
torch.save(model, 'net.pkl') # 同上
torch.save(model.state_dict(), 'net_params.pth') # 隻儲存模型參數
torch.save(model.state_dict(), 'net_params.pkl') # 同上
# 加載模型步驟
model = torch.load('net.pth') # 加載整個神經網絡的模型結構以及參數
model = torch.load('net.pkl') # 同上
model.load_state_dict(torch.load('net_params.pth')) # 僅加載參數
model.load_state_dict(torch.load('net_params.pkl')) # 同上
上面例子也可以看出若使用
torch.save()
來進行模型參數的儲存,那儲存檔案的字尾其實沒有任何影響,.pkl 檔案和 .pth 檔案一模一樣。
二、pkl、pth檔案差別
實際上,這兩種格式的檔案還是有差別的。
2.1 .pkl檔案
首先介紹 .pkl 檔案,它若直接打開會顯示一堆序列化的東西,以二進制形式存儲的。如果去 read 這些檔案,需要用
'rb'
而不是
'r'
模式。
import pickle as pkl
file = os.path.join('annot',model.pkl) # 打開pkl檔案
with open(file, 'rb') as anno_file:
result = pkl.load(anno_file)
或者:
import pickle as pkl
file = os.path.join('annot',model.pkl) # 打開pkl檔案
anno_file = open(file, 'rb')
result = pkl.load(anno_file)
2.2 .pth檔案
import torch
filename = r'E:\anaconda\model.pth' # 字元串前面加r,表示的意思是禁止字元串轉義
model = torch.load(filename)
print(model)
但其實不管pkl檔案還是pth檔案,都是以二進制形式存儲的,沒有本質上的差別,你用pickle這個庫去加載pkl檔案或pth檔案,效果都是一樣的。