天天看點

Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

目錄

  • 1. 準備資料集
    • 1.1 MNIST資料集擷取:
    • 1.2 程式部分
  • 2. 設計網絡結構
    • 2.1 網絡設計
    • 2.2 程式部分
  • 3. 疊代訓練
  • 4. 測試集預測部分
  • 5. 全部代碼

1. 準備資料集

1.1 MNIST資料集擷取:

  • torchvision.datasets接口直接下載下傳,該接口可以直接建構資料集,推薦
  • 其他途徑下載下傳後,編寫程式進行讀取,然後由Datasets建構自己的資料集

​ ​ 本文使用第一種方法擷取資料集,并使用Dataloader進行按批裝載。如果使用程式下載下傳失敗,請将其他途徑下載下傳的MNIST資料集 [檔案] 和 [解壓檔案] 放置在 <data/MNIST/raw/> 位置下,本文的程式及檔案結構圖如下:

Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 其中,model檔案夾用來存儲每個epoch訓練的模型參數,根檔案夾下包含model.py用于訓練模型,test.py為測試集測試,show.py為展示部分

1.2 程式部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

# 1. 準備資料集
## 1.1 使用torchvision自動下載下傳MNIST資料集
train_data = datasets.MNIST(root='data\\',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

## 1.2 建構資料集裝載器
train_loader = DataLoader(dataset=train_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("===============資料統計===============")
    print("訓練集樣本:",train_data.__len__(), train_data.data.shape)
           
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 【代碼解析】

  • root為存放MNIST的路徑,trian=True代表下載下傳的為訓練集和訓練集标簽,False則代表測試集和标簽
  • transforms.ToTensor()表示将shape為(H, W, C)的 numpy 數組或 img 轉為shape為(C, H, W)的tensor,并将數值歸一化為[0,1]
  • download為True則代表自動下載下傳,若該檔案夾下已經下載下傳,則直接跳過下載下傳步驟
  • shuffle=True,表示對分好的batch進行洗牌操作,drop_last=True表示對最後不足batch大小的剩餘樣本舍去,False表示保留
  • num_works表示每次讀取的程序數,和核心數有關

​ ​ Dataset和Dataloader詳細說明,請移步:[Pytorch Dataset和Dataloader 學習筆記(二)]

2. 設計網絡結構

2.1 網絡設計

Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 網絡結構如上圖所示,輸入圖像—>卷積1—>池化1—>卷積2—>池化2—>全連接配接1—>全連接配接2—>softmax,每次卷積通道數都增加一倍,最後送入全連接配接層實作分類

2.2 程式部分

# 2. Design model using class
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.max_pooling1 = nn.MaxPool2d(2)
        self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pooling2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1568, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.max_pooling1(F.relu(self.conv_layer1(x)))
        x = self.max_pooling2(F.relu(self.conv_layer2(x)))
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        y_hat = self.fc2(x)     # CrossEntropyLoss會自動激活最後一層的輸出以及softmax處理
        return y_hat

net = Net()

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
           
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 【代碼解析】

  • fc1的1568次元是因為最後一次池化後的shape為32*7*7=1568
  • 在最後一層,并沒有進行relu激活以及接入softmax,是因為,在CrossEntropyLoss中會自動激活最後一層的輸出以及softmax處理
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ CrossEntropyLoss圖參考:《PyTorch深度學習實踐》完結合集

​ ​ 詳細網絡結構搭建說明,請移步:Pytorch線性規劃模型 學習筆記(一)

3. 疊代訓練

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

# 4. Training
if __name__ == "__main__":
    print("Training...")
    for epoch in range(20):
        strat = time.time()
        total_correct = 0
        for x, y in train_loader:
            y_hat = net(x)
            y_pre = torch.argmax(y_hat, dim=1)
            total_correct += sum(torch.eq(y_pre, y))    # 統計目前epoch下的正确個數

            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        acc = (float(total_correct) / train_data.__len__())*100
        save_path = "model/net" + str(epoch+1) + ".pth"
        torch.save(obj=net.state_dict(), f=save_path)
        print("epoch:", str(epoch + 1) + "/20",
              " \n time:", "%.1f" % (time.time() - strat) + "s"    
              " train_loss:", loss.item(),
              " acc:%.3f%%" % acc,)

    print("we are done!")
           
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 【代碼解析】

  • total_correct變量用于統計每個epoch下正确預測值的個數,每進行epoch進行一次清零
  • torch.argmax(y_hat, dim=1)用于選取y_hat下每一行的最大值(每個樣本的最高得分),并傳回與y相同次元的tensor
  • torch.eq(y_pre, y)用于比較兩個矩陣元素是否相同,相同則傳回True,不同則傳回False,用于判斷預測值與真實值是否相同
  • torch.save儲存了每個epoch的網絡權重參數

4. 測試集預測部分

# 測試模型,測試集為test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net

test_data = datasets.MNIST(root='data\\',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)
test_loader = DataLoader(dataset=test_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("---------------預測分析---------------")
    print("測試集樣本:", test_data.__len__(), test_data.data.shape)
    model = Net()
    model.load_state_dict(torch.load("model/net20.pth"))
    model.eval()

    total_correct = 0
    for x, y in test_loader:
        y_hat = model(x)
        y_pre = torch.argmax(y_hat, dim=1)
        total_correct += sum(torch.eq(y_pre, y))

    acc = (float(total_correct) / test_data.__len__())*100
    print("total_test_samples:", test_data.__len__(),
          " test_acc:", "%.3f%%" % acc)
           
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

​ ​ 經過20個epoch的訓練,在測試集上達到了98.590%的準确率,部分batch真實值與預測值展示如下:

Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)
Pytorch CNN網絡MNIST數字識别 [超詳細記錄] 學習筆記(三)

5. 全部代碼

連結:連結:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw

提取碼:82l4

轉載請說明出處