
PyTorch: Saving and Loading Models

Contents

1. Saving a Model

2. Loading a Model

3. Resuming Training from a Checkpoint

1. Saving a Model

Method 1: save the entire Module: torch.save(net, path)

Method 2: save only the model parameters:

state_dict = net.state_dict()

torch.save(state_dict, path)

import torch
import numpy as np
import torch.nn as nn

# Build a simple LeNet2 network model
class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2020914)
            
# Saving the model
net = LeNet2(classes=2020)

# "Train" the model (here we just fill the parameters to simulate training)
print("Before training: ", net.features[0].weight[0, ...])
net.initialize()
print("After training: ", net.features[0].weight[0, ...])

# Method 1: save the entire model
# It is saved directly into the model folder
torch.save(net, "./model/model.pkl")

# Method 2: save only the model's parameters
state_dict = net.state_dict()
torch.save(state_dict, './model/model_state_dict.pkl')
           

After running this, the corresponding files are generated in the ./model folder.
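One practical caveat not covered by the original snippet: torch.save does not create missing directories, so the ./model folder has to exist before the calls above. A small sketch:

import os

# torch.save only writes to an existing path; create the folder up front
os.makedirs("./model", exist_ok=True)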


2. Loading a Model

Loading simply restores the network model we saved above:

net = torch.load(model_path)

# Load the entire saved model
net_load = torch.load("./model/model.pkl")
print(net_load)
           
LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2020, bias=True)
  )
)      
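One caveat worth noting: saving the whole model pickles the LeNet2 class by reference, so the class definition must be importable when torch.load is called. If the model was saved on a GPU machine and is loaded on a CPU-only one, torch.load also accepts a map_location argument to remap the storages; a minimal sketch:

# Hedged sketch: remap storages to CPU when the file was saved on a GPU
net_load = torch.load("./model/model.pkl", map_location="cpu")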

Loading the model parameters saved in step 1:

state_dict_load = torch.load(path_state_dict)

state_dict_load = torch.load("./model/model_state_dict.pkl")
print(state_dict_load)
           
OrderedDict([('features.0.weight', tensor([[[[2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.]],

         [[2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.]],

         [[2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.],
          [2020914., 2020914., 2020914., 2020914., 2020914.]]],............      

So far we have only loaded the saved parameters into an OrderedDict; how do we actually put them into a model?

# Create a new network and load the saved parameters into it
net_new = LeNet2(classes=2020)
print("Before loading: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("After loading: ", net_new.features[0].weight[0, ...])
           
Before loading:  tensor([[[-3.9705e-02,  7.5627e-02, -8.3019e-02,  3.5177e-02,  6.7630e-02],
         [ 4.4209e-02, -4.2972e-02, -1.0790e-01,  4.8304e-03,  1.1456e-01],
         [-3.5768e-03, -8.3255e-02, -1.0525e-02,  8.4690e-02, -2.5149e-02],
         [-3.7954e-02,  8.5832e-03, -7.2547e-02, -1.0731e-01,  3.5220e-02],
         [-1.0719e-01,  8.4218e-02,  6.8099e-02, -8.5679e-02, -6.9766e-02]],

        [[-4.3151e-02,  6.2990e-02, -7.2684e-02,  6.7785e-02,  1.0125e-02],
         [-5.2818e-02,  9.3259e-02, -4.1017e-04, -1.0336e-01, -8.8489e-02],
         [-6.2203e-02, -5.9651e-02,  1.9473e-02, -1.1111e-01, -1.0471e-01],
         [-1.1488e-02,  1.1124e-01,  5.3528e-03, -2.4913e-02, -1.0143e-01],
         [ 1.7090e-03,  1.1159e-01, -4.1312e-02,  9.2102e-03,  8.9532e-02]],

        [[ 1.8905e-02, -7.9735e-02, -5.6365e-02,  4.9867e-02, -1.2206e-02],
         [-4.1118e-02, -8.3310e-02, -7.8296e-02,  7.2381e-02, -7.9311e-05],
         [ 9.2661e-02, -7.5984e-03,  6.1938e-02,  4.9871e-03,  3.9456e-02],
         [ 5.3209e-02, -1.3996e-02, -1.1026e-01, -7.9629e-03, -6.1041e-02],
         [-8.8979e-02, -3.8268e-02,  5.3847e-03, -9.7152e-02, -8.5485e-02]]],
       grad_fn=<SelectBackward>)
After loading:  tensor([[[2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.]],

        [[2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.]],

        [[2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.],
         [2020914., 2020914., 2020914., 2020914., 2020914.]]],
       grad_fn=<SelectBackward>)      
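A side note beyond the original example: by default load_state_dict requires every key in the loaded dict to match a parameter or buffer of the model (strict=True). When the architectures differ slightly, for instance after swapping the classifier head, strict=False loads only the matching keys and reports the rest; a small sketch:

# Hedged sketch: load only the keys that match, and inspect what was skipped
result = net_new.load_state_dict(state_dict_load, strict=False)
print(result.missing_keys, result.unexpected_keys)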

Beyond the basic save/load patterns above, there is a more practical saving strategy: checkpointing, so that an interrupted run can be resumed.


3. Resuming Training from a Checkpoint

In case training is interrupted unexpectedly, save a checkpoint at regular intervals during training:

# Save a checkpoint every 5 epochs
checkpoint_interval = 5
if (epoch+1) % checkpoint_interval == 0:
    checkpoint = {"model_state_dict":net.state_dict(),
                 "optimizer_state_dict":optimizer.state_dict(),
                 "epoch":epoch}
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
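For context, here is a minimal sketch of where this checkpoint save sits inside a full training loop; MAX_EPOCH, train_loader and criterion are illustrative placeholders, not names from the original code:

# Hedged sketch of a training loop with periodic checkpointing.
MAX_EPOCH = 10
checkpoint_interval = 5
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for epoch in range(MAX_EPOCH):
    for inputs, labels in train_loader:   # train_loader is a hypothetical DataLoader
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        torch.save(checkpoint, "./checkpoint_{}_epoch.pkl".format(epoch))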
           

To resume training from a saved checkpoint:

# Load the checkpoint before resuming training
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch
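
With the checkpoint restored, the training loop continues from the epoch after the saved one instead of starting over. A minimal sketch, reusing the hypothetical MAX_EPOCH, train_loader and criterion from the loop sketched above:

# Hedged sketch: resume the epoch loop right after the restored epoch
for epoch in range(start_epoch + 1, MAX_EPOCH):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # the LR schedule picks up where it left off via last_epoch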
           
