在辛辛苦苦訓練好模型之後,我們想将它儲存起來,或我們想使用已經訓練完成的模型。那麼該如何是實作呢?
本文參考:https://pytorch.org/tutorials/beginner/saving_loading_models.html
本文将以一個CNN模型示範如何儲存或加載以訓練好的模型。
首先給訓練過程:
import torch
import torch.nn
import torch.optim
import torch.utils.data
import torchvision
import numpy
import matplotlib
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
num_epoch = 5
batch_size = 100
learning_rate = 0.001
# -------------------------------------------------------------------------
# root 用于指定資料集在下載下傳之後的存放路徑
# transform 用于指定導入資料集需要對資料進行那種變化操作
# train是指定在資料集下載下傳完成後需要載入那部分資料,
# 如果設定為True 則說明載入的是該資料集的訓練集部分
# 如果設定為FALSE 則說明載入的是該資料集的測試集部分
data_train = datasets.MNIST(root="./data/",
transform=transforms.ToTensor(),
train=True,
download=True)
data_test = datasets.MNIST(root="./data/",
transform=transforms.ToTensor(),
train=False)
# ______________________________________________________________________________
# 下面對資料進行裝載,我們可以将資料的載入了解為對圖檔的處理,
# 在處理完成後,我們就需要将這些圖檔打包好送給我們的模型進行訓練 了 而裝載就是這個打包的過程
# dataset 參數用于指定我們載入的資料集名稱
# batch_size參數設定了每個包中的圖檔資料個數
# 在裝載的過程會将資料随機打亂順序并進打包
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
batch_size=batch_size,
shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
batch_size=batch_size,
shuffle=True)
class CNN(torch.nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 16, kernel_size=5, padding=2),
# 用于搭建卷積神經網絡的卷積層,主要的輸入參數有輸入通道數、輸出通道數、
# 卷積核大小、卷積核移動步長和Paddingde值。其中,輸入通道數的資料類型是
# 整型,用于确定輸入資料的層數;輸出通道數的資料類型也是整型,用于确定
# 輸出資料的層數;卷積核大小的資料類型是整型,用于确定卷積核的大小;
# 卷積核移動步長的資料類型是整型,用于确定卷積核每次滑動的步長;
# Paddingde 的資料類型是整型,值為0時表示不進行邊界像素的填充,
# 如果值大于0,那麼增加數字所對應的邊界像素層數。
torch.nn.BatchNorm2d(16),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
# 用于實作卷積神經網絡中的最大池化層,主要的輸入參數是池化視窗大小、
# 池化視窗移動步長和Padding的值。
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(16, 32, kernel_size=5, padding=2),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.fc = torch.nn.Linear(7 * 7 * 32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
cnn = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cnn.to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
for i, data in enumerate(data_loader_train):
images, labels = data[0].to(device), data[1].to(device)
outputs = cnn(images)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
% (epoch + 1, num_epoch, i + 1, len(data_train), loss.item()))
- 方法一(推薦)
訓練完成之後:使用如下的方式儲存模型:
可以看到在同一級目錄下出現了一個名為cnn.pkl的檔案這個就是模型本尊了。
這種方法僅僅儲存了模型的重要參數而非整個模型
通過運作如下代碼可以顯示cnn.state_dict()中包含了那些内容(這裡由于篇幅的原因,隻顯示了大小,沒有顯示具體數值):
print("Model's state_dict:")
for param_tensor in cnn.state_dict():
print(param_tensor, "\t", cnn.state_dict()[param_tensor].size())
輸出結果為:
Model's state_dict:
conv1.0.weight torch.Size([16, 1, 5, 5])
conv1.0.bias torch.Size([16])
conv1.1.weight torch.Size([16])
conv1.1.bias torch.Size([16])
conv1.1.running_mean torch.Size([16])
conv1.1.running_var torch.Size([16])
conv1.1.num_batches_tracked torch.Size([])
conv2.0.weight torch.Size([32, 16, 5, 5])
conv2.0.bias torch.Size([32])
conv2.1.weight torch.Size([32])
conv2.1.bias torch.Size([32])
conv2.1.running_mean torch.Size([32])
conv2.1.running_var torch.Size([32])
conv2.1.num_batches_tracked torch.Size([])
fc.weight torch.Size([10, 1568])
fc.bias torch.Size([10])
可以發現這些都是學習的參數資訊。
加載這種方式儲存得模型時,使用如下的方式:
cnn_new=CNN()
cnn_new.load_state_dict(torch.load('cnn.pkl'))
cnn_new.eval()
必須調用model.eval()将dropout和批處理規範化層設定為評估模式。
-
方法二
還有一種方法可以儲存整個模型
cnn_new=CNN()
cnn_new= torch.load('cnn.pkl')
cnn_new.eval()
- 方法三:
儲存多個通用檢查點
儲存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
},'cnn.pkl')
加載
cnn_new=CNN()
optimizer_new=torch.optim.Adam(cnn.parameters(), lr=learning_rate)
checkpoint = torch.load('cnn.pkl')
cnn_new.load_state_dict(checkpoint['model_state_dict'])
optimizer_new.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_new = checkpoint['epoch']
loss_new = checkpoint['loss']
cnn_new.eval()
官網還給出了很多方式,本文會在實際操作之後再更新文章。