Pytorch儲存和加載模型參數
- 參考部落格
- 1. 儲存模型和參數
- 2. 僅儲存參數
- 3. 加載pytorch預訓練模型
-
- 3.1 加載預訓練模型和參數
- 3.2 隻加載模型不加載預訓練參數
參考部落格
參考部落格: https://blog.csdn.net/lscelory/article/details/81482586
pytorch的模型和參數是分開的,可以分别儲存或加載模型和參數。
pytorch有兩種模型儲存方式:
- 儲存整個神經網絡的的結構資訊和模型參數資訊,save的對象是網絡net
- 隻儲存神經網絡的訓練模型參數,save的對象是net.state_dict()
對應兩種儲存模型的方式,pytorch也有兩種加載模型的方式。對應第一種儲存方式,加載模型時通過torch.load(’.pth’)直接初始化新的神經網絡對象;對應第二種儲存方式,需要首先導入對應的網絡,再通過net.load_state_dict(torch.load(’.pth’))完成模型參數的加載。
在網絡比較大的時候,第一種方法會花費較多的時間。
1. 儲存模型和參數
- 儲存模型
# 将網絡結構和模型參數都儲存起來,在測試時可以直接加載,不需要初始化網絡結構
torch.save(model, path)
參數:
model: 訓練的網絡
pth: 儲存的路徑(包含檔案名,字尾名以.pth .pkl 等結尾)
執行個體:
torch.save(model, os.path.join('.', 'lenet.pth')) # 儲存模型結構和參數
- 加載模型
# 直接加載模型檔案,不需要初始化網絡結構
model = torch.load(path)
參數:
model: 加載後的網絡
pth: 儲存模型檔案的路徑(包含檔案名,字尾名以.pth .pkl 等結尾)
執行個體:
model = torch.load(os.path.join('.', 'lenet.pth'))# 加載模型結構和參數
2. 僅儲存參數
- 儲存參數
# 将lenet模型儲存為lenet.pth, 注意儲存的僅僅是網絡模型的狀态資訊參數字典, 加載是需要初始化網絡模型
torch.save(net.state_dict(), os.path.join('.', 'lenet.pth'))
- 加載參數
# 加載lenet,模型存放在lenet.pth, 加載之前要确認網絡模型已初始化完成
model = torch.load(os.path.join('.', 'lenet.pth'))
net.load_state_dict(model)
3. 加載pytorch預訓練模型
3.1 加載預訓練模型和參數
import torchvision
AlexNet = torchvision.models.alexnet(pretrained=True) # 加載預訓練模型AlexNet和參數
resnet18 = torchvision.models.resnet18(pretrained=True)
3.2 隻加載模型不加載預訓練參數
import torchvision
AlexNet = torchvision.models.alexnet(pretrained=False) # 加載預訓練模型AlexNet
# 導入模型結構
ResNet18 = models.resnet18(pretrained=False)
# 加載預先下載下傳好的預訓練參數到resnet18
ResNet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))