Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)
目錄
- 1. 準備資料集
- 1.1 MNIST資料集擷取:
- 1.2 程式部分
- 2. 設計網絡結構
- 2.1 網絡設計
- 2.2 程式部分
- 3. 疊代訓練
- 4. 測試集預測部分
- 5. 全部代碼
1. 準備資料集
1.1 MNIST資料集擷取:
- torchvision.datasets接口直接下載下傳,該接口可以直接建構資料集,推薦
- 其他途徑下載下傳後,編寫程式進行讀取,然後由Datasets建構自己的資料集
本文使用第一種方法擷取資料集,并使用Dataloader進行按批裝載。如果使用程式下載下傳失敗,請将其他途徑下載下傳的MNIST資料集 [檔案] 和 [解壓檔案] 放置在 <data/MNIST/raw/> 位置下,本文的程式及檔案結構圖如下:

其中,model檔案夾用來存儲每個epoch訓練的模型參數,根檔案夾下包含model.py用于訓練模型,test.py為測試集測試,show.py為展示部分
1.2 程式部分
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
# 1. 準備資料集
## 1.1 使用torchvision自動下載下傳MNIST資料集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True)
## 1.2 建構資料集裝載器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4)
if __name__ == "__main__":
print("===============資料統計===============")
print("訓練集樣本:",train_data.__len__(), train_data.data.shape)
【代碼解析】
- root為存放MNIST的路徑,trian=True代表下載下傳的為訓練集和訓練集标簽,False則代表測試集和标簽
- transforms.ToTensor()表示将shape為(H, W, C)的 numpy 數組或 img 轉為shape為(C, H, W)的tensor,并将數值歸一化為[0,1]
- download為True則代表自動下載下傳,若該檔案夾下已經下載下傳,則直接跳過下載下傳步驟
- shuffle=True,表示對分好的batch進行洗牌操作,drop_last=True表示對最後不足batch大小的剩餘樣本舍去,False表示保留
- num_works表示每次讀取的程序數,和核心數有關
Dataset和Dataloader詳細說明,請移步:[Pytorch Dataset和Dataloader 學習筆記(二)]
2. 設計網絡結構
2.1 網絡設計
網絡結構如上圖所示,輸入圖像—>卷積1—>池化1—>卷積2—>池化2—>全連接配接1—>全連接配接2—>softmax,每次卷積通道數都增加一倍,最後送入全連接配接層實作分類
2.2 程式部分
# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss會自動激活最後一層的輸出以及softmax處理
return y_hat
net = Net()
# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
【代碼解析】
- fc1的1568次元是因為最後一次池化後的shape為32*7*7=1568
- 在最後一層,并沒有進行relu激活以及接入softmax,是因為,在CrossEntropyLoss中會自動激活最後一層的輸出以及softmax處理
CrossEntropyLoss圖參考:《PyTorch深度學習實踐》完結合集
詳細網絡結構搭建說明,請移步:Pytorch線性規劃模型 學習筆記(一)
3. 疊代訓練
# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
# 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 統計目前epoch下的正确個數
loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,)
print("we are done!")
【代碼解析】
- total_correct變量用于統計每個epoch下正确預測值的個數,每進行epoch進行一次清零
- torch.argmax(y_hat, dim=1)用于選取y_hat下每一行的最大值(每個樣本的最高得分),并傳回與y相同次元的tensor
- torch.eq(y_pre, y)用于比較兩個矩陣元素是否相同,相同則傳回True,不同則傳回False,用于判斷預測值與真實值是否相同
- torch.save儲存了每個epoch的網絡權重參數
4. 測試集預測部分
# 測試模型,測試集為test_data
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net
test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4)
if __name__ == "__main__":
print("---------------預測分析---------------")
print("測試集樣本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval()
total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y))
acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)
經過20個epoch的訓練,在測試集上達到了98.590%的準确率,部分batch真實值與預測值展示如下:
5. 全部代碼
連結:連結:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw
提取碼:82l4
轉載請說明出處