天天看點

PyTorch儲存網絡結構以及參數【 torch.save()、torch.load() 】 一、儲存方式 二、pkl、pth檔案差別

儲存網絡結構以及參數

  • 一、儲存方式
  • 二、pkl、pth檔案差別
      • 2.1 .pkl檔案
      • 2.2 .pth檔案

       對于pytorch儲存網絡參數,大家一般可以看到有 .pkl檔案 以及 .pth檔案,對于這兩者有什麼差別,以及如何儲存網絡參數等,本文就好好講述一下。

一、儲存方式

       首先我們知道不論是儲存模型還是參數都需要用到

torch.save()

對于

torch.save()

有兩種儲存方式:

  • 隻儲存神經網絡的訓練模型的參數,save的對象是

    model.state_dict()

  • 既儲存整個神經網絡的的模型結構又儲存模型參數,那麼save的對象就是整個模型;

Eg. 假設我有一個訓練好的模型名叫model,如何來儲存參數以及結構?

import torch

# 儲存模型步驟
torch.save(model, 'net.pth')  # 儲存整個神經網絡的模型結構以及參數
torch.save(model, 'net.pkl')  # 同上
torch.save(model.state_dict(), 'net_params.pth')  # 隻儲存模型參數
torch.save(model.state_dict(), 'net_params.pkl')  # 同上

# 加載模型步驟
model = torch.load('net.pth')  # 加載整個神經網絡的模型結構以及參數
model = torch.load('net.pkl')  # 同上
model.load_state_dict(torch.load('net_params.pth')) # 僅加載參數
model.load_state_dict(torch.load('net_params.pkl')) # 同上
           

       上面例子也可以看出若使用

torch.save()

來進行模型參數的儲存,那儲存檔案的字尾其實沒有任何影響,.pkl 檔案和 .pth 檔案一模一樣。

二、pkl、pth檔案差別

實際上,這兩種格式的檔案還是有差別的。

2.1 .pkl檔案

       首先介紹 .pkl 檔案,它若直接打開會顯示一堆序列化的東西,以二進制形式存儲的。如果去 read 這些檔案,需要用

'rb'

而不是

'r'

模式。

import pickle as pkl

file = os.path.join('annot',model.pkl) # 打開pkl檔案
with open(file, 'rb') as anno_file:
	result = pkl.load(anno_file)
           

或者:

import pickle as pkl

file = os.path.join('annot',model.pkl) # 打開pkl檔案
anno_file = open(file, 'rb')
result = pkl.load(anno_file)
           

2.2 .pth檔案

import torch

filename = r'E:\anaconda\model.pth' # 字元串前面加r,表示的意思是禁止字元串轉義
model = torch.load(filename)
print(model)
           

但其實不管pkl檔案還是pth檔案,都是以二進制形式存儲的,沒有本質上的差別,你用pickle這個庫去加載pkl檔案或pth檔案,效果都是一樣的。

繼續閱讀