天天看點

深度學習--第15篇: Pytorch儲存和加載模型參數參考部落格1. 儲存模型和參數2. 僅儲存參數3. 加載pytorch預訓練模型

Pytorch儲存和加載模型參數

  • 參考部落格
  • 1. 儲存模型和參數
  • 2. 僅儲存參數
  • 3. 加載pytorch預訓練模型
    • 3.1 加載預訓練模型和參數
    • 3.2 隻加載模型不加載預訓練參數

參考部落格

參考部落格: https://blog.csdn.net/lscelory/article/details/81482586

pytorch的模型和參數是分開的,可以分别儲存或加載模型和參數。

pytorch有兩種模型儲存方式:

  • 儲存整個神經網絡的的結構資訊和模型參數資訊,save的對象是網絡net
  • 隻儲存神經網絡的訓練模型參數,save的對象是net.state_dict()

對應兩種儲存模型的方式,pytorch也有兩種加載模型的方式。對應第一種儲存方式,加載模型時通過torch.load(’.pth’)直接初始化新的神經網絡對象;對應第二種儲存方式,需要首先導入對應的網絡,再通過net.load_state_dict(torch.load(’.pth’))完成模型參數的加載。

在網絡比較大的時候,第一種方法會花費較多的時間。

1. 儲存模型和參數

  • 儲存模型
# 将網絡結構和模型參數都儲存起來,在測試時可以直接加載,不需要初始化網絡結構
torch.save(model, path)

參數:
	model: 訓練的網絡
	pth: 儲存的路徑(包含檔案名,字尾名以.pth .pkl 等結尾)

執行個體:
torch.save(model, os.path.join('.', 'lenet.pth')) # 儲存模型結構和參數
           
  • 加載模型
# 直接加載模型檔案,不需要初始化網絡結構
model = torch.load(path)

參數:
	model: 加載後的網絡
	pth: 儲存模型檔案的路徑(包含檔案名,字尾名以.pth .pkl 等結尾)

執行個體:
model = torch.load(os.path.join('.', 'lenet.pth'))# 加載模型結構和參數
           

2. 僅儲存參數

  • 儲存參數
# 将lenet模型儲存為lenet.pth, 注意儲存的僅僅是網絡模型的狀态資訊參數字典, 加載是需要初始化網絡模型
torch.save(net.state_dict(), os.path.join('.', 'lenet.pth'))
           
  • 加載參數
# 加載lenet,模型存放在lenet.pth, 加載之前要确認網絡模型已初始化完成
model = torch.load(os.path.join('.', 'lenet.pth'))
net.load_state_dict(model)
           

3. 加載pytorch預訓練模型

3.1 加載預訓練模型和參數

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=True) # 加載預訓練模型AlexNet和參數

resnet18 = torchvision.models.resnet18(pretrained=True)
           

3.2 隻加載模型不加載預訓練參數

import torchvision
AlexNet = torchvision.models.alexnet(pretrained=False) # 加載預訓練模型AlexNet

# 導入模型結構
ResNet18 = models.resnet18(pretrained=False)
# 加載預先下載下傳好的預訓練參數到resnet18
ResNet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
           

繼續閱讀