目錄
1、模型儲存
2、模型加載
3、斷點續訓練
1、模型儲存
方法一:儲存整個Module:torch.save(net,path)
方法二:儲存模型參數:
state_dict = net.state_dict()
torch.save(state_dict,path)
import os

import numpy as np
import torch
import torch.nn as nn
# 這裡簡單建立一個LeNet2網絡模型
class LeNet2(nn.Module):
    """Small LeNet-style CNN used to demonstrate saving and loading.

    The network is two conv/ReLU/max-pool stages followed by a three-layer
    fully connected classifier. The classifier's first layer takes
    16 * 5 * 5 features, which is what the feature extractor produces for
    3-channel 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        classes: Number of output units of the final linear layer.
    """

    def __init__(self, classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),
        )

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        x = self.features(x)
        # Keep the batch dimension and collapse everything else.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        """Overwrite every parameter with the constant 2020914.

        This is a stand-in for "training" so that the saved weights are
        easy to recognise after loading.
        """
        # Fill in place without recording the operation in autograd;
        # this replaces the discouraged `p.data` access.
        with torch.no_grad():
            for p in self.parameters():
                p.fill_(2020914)
# Model saving demo
net = LeNet2(classes=2020)

# "Training": initialize() overwrites every parameter with a constant,
# so the weights printed before and after differ visibly.
print("訓練前: ", net.features[0].weight[0, ...])
net.initialize()
print("訓練後: ", net.features[0].weight[0, ...])

# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing into it.
os.makedirs("./model", exist_ok=True)

# Save the whole model object (the module itself is pickled).
torch.save(net, "./model/model.pkl")

# Save only the model parameters (the state_dict).
state_dict = net.state_dict()
torch.save(state_dict, './model/model_state_dict.pkl')
然後就會在對應檔案夾中生成對應的檔案
2、模型加載
就是加載我們上述儲存的網絡模型
net = torch.load(model_path)
# Load the whole pickled model back.
# NOTE: since PyTorch 2.6 torch.load defaults to weights_only=True, which
# refuses to unpickle a full nn.Module; weights_only=False restores the
# behaviour this example relies on (the keyword exists from PyTorch 1.13).
# Unpickling can execute arbitrary code, so only load files you trust.
# The LeNet2 class must be defined/importable when this runs.
net_load = torch.load("./model/model.pkl", weights_only=False)
print(net_load)
LeNet2(
(features): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=2020, bias=True)
)
)
加載第一步中儲存的模型參數
state_dict_load = torch.load(path_state_dict)
# Read the saved parameter dictionary back from disk and show its contents.
path_state_dict = "./model/model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load)
OrderedDict([('features.0.weight', tensor([[[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]],
[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]],
[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]]],............
上面只是把上一次儲存的參數讀取出來,那麼如何把這些參數更新到網絡模型中呢?
# Build a fresh network with the same architecture; its weights start out
# with PyTorch's default initialisation.
net_new = LeNet2(classes=2020)
first_conv = net_new.features[0]
print("加載前: ", first_conv.weight[0, ...])

# Copy the saved parameters into the new network, then show the same
# kernel again to confirm the values were replaced.
net_new.load_state_dict(state_dict_load)
print("加載後: ", first_conv.weight[0, ...])
加載前: tensor([[[-3.9705e-02, 7.5627e-02, -8.3019e-02, 3.5177e-02, 6.7630e-02],
[ 4.4209e-02, -4.2972e-02, -1.0790e-01, 4.8304e-03, 1.1456e-01],
[-3.5768e-03, -8.3255e-02, -1.0525e-02, 8.4690e-02, -2.5149e-02],
[-3.7954e-02, 8.5832e-03, -7.2547e-02, -1.0731e-01, 3.5220e-02],
[-1.0719e-01, 8.4218e-02, 6.8099e-02, -8.5679e-02, -6.9766e-02]],
[[-4.3151e-02, 6.2990e-02, -7.2684e-02, 6.7785e-02, 1.0125e-02],
[-5.2818e-02, 9.3259e-02, -4.1017e-04, -1.0336e-01, -8.8489e-02],
[-6.2203e-02, -5.9651e-02, 1.9473e-02, -1.1111e-01, -1.0471e-01],
[-1.1488e-02, 1.1124e-01, 5.3528e-03, -2.4913e-02, -1.0143e-01],
[ 1.7090e-03, 1.1159e-01, -4.1312e-02, 9.2102e-03, 8.9532e-02]],
[[ 1.8905e-02, -7.9735e-02, -5.6365e-02, 4.9867e-02, -1.2206e-02],
[-4.1118e-02, -8.3310e-02, -7.8296e-02, 7.2381e-02, -7.9311e-05],
[ 9.2661e-02, -7.5984e-03, 6.1938e-02, 4.9871e-03, 3.9456e-02],
[ 5.3209e-02, -1.3996e-02, -1.1026e-01, -7.9629e-03, -6.1041e-02],
[-8.8979e-02, -3.8268e-02, 5.3847e-03, -9.7152e-02, -8.5485e-02]]],
grad_fn=<SelectBackward>)
加載後: tensor([[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]],
[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]],
[[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.],
[2020914., 2020914., 2020914., 2020914., 2020914.]]],
grad_fn=<SelectBackward>)
除了上述的儲存和加載方式,其實還有一種比較人性化的儲存方法:斷點續訓練
3、斷點續訓練
為了防止模型在訓練過程中突然出現意外而中斷,可以在訓練過程中每隔固定的epoch進行一次斷點儲存
# Save a checkpoint every 5 epochs (the condition counts epochs, not batches).
# NOTE(review): this fragment uses `epoch`, `net` and `optimizer`, so it
# presumably sits inside the per-epoch training loop — confirm placement.
checkpoint_interval = 5
if (epoch + 1) % checkpoint_interval == 0:
    # Bundle the model weights, the optimizer state and the epoch index
    # into one dictionary so they are written to a single file.
    checkpoint = {
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }
    # The file name carries the raw epoch index, e.g. epoch == 4 gives
    # "./checkpoint_4_epoch.pkl".
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
斷點續訓練
# Restore the saved training state before the training loop starts.
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

# Epoch index stored in the checkpoint.
# NOTE(review): presumably the training loop continues from
# start_epoch + 1 — confirm against the loop, which is not shown here.
start_epoch = checkpoint['epoch']

# Put the saved weights and optimizer state back into the live objects.
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Set the scheduler's epoch counter to the restored epoch index.
scheduler.last_epoch = start_epoch