天天看點

手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

本篇文章完整代碼可以在我的公衆号【拇指筆記】背景回複"softmax_py"擷取

文末有二維碼~

識别效果:

手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

1.softmax回歸

這一部分分為softmax回歸模型的概念、圖像分類資料集的概念、softmax回歸模型的實作和softmax回歸模型基于pytorch架構的實作四部分。

對于離散值預測問題,我們可以使用諸如softmax回歸這樣的分類模型。softmax回歸模型有多個輸出單元。本章以softmax回歸模型為例,介紹神經網絡中的分類模型。

1.1分類問題

例如一個簡單的圖像分類問題,輸入圖形高和寬都為2像素,且色彩為灰階(灰階圖像的像素值可以用一個标量來表示)。我們将圖像的四個像素值記為x1,x2,x3,x4。假設訓練資料集中圖像的真實标簽為狗 貓和雞,這些标簽分别對應着離散值y1,y2,y3。

我們通常使用離散值來表示類别,例如y1=1,y2=2,y3=3。一張圖像的标簽為1、2和3的數值中的一個,對于這種問題,我們一般使用更加适合離散輸出的模型來解決分類問題。

1.2softmax回歸模型

softmax回歸模型一樣将輸入特征與權重做線性疊加。于線性回歸的主要差別為softmax回歸的輸出值個數等于标簽裡的類别數。

在上面的例子中,每個圖像又四個像素,對應着每個圖象有四個特征值(x),有三種可能的動物類别,對應着三

個離散值标簽(o)。是以包含12個權重(w)和3個偏差(b)

o 1 = w 11 x 1 + w 21 x 2 + w 31 x 3 + w 41 x 4 + b 1 , o 2 = w 12 x 1 + w 22 x 2 + w 32 x 3 + w 42 x 4 + b 2 , o 3 = w 13 x 1 + w 23 x 2 + w 33 x 3 + w 43 x 4 + b 3 , w 下 标 命 名 規 則 : 不 同 列 代 表 不 同 輸 出 類 型 , 不 同 行 代 表 不 同 像 素 點 。 列 數 代 表 真 實 輸 出 的 類 别 數 ; 行 數 代 表 特 征 數 。 o_1=w_{11}x_1+w_{21}x_2+w_{31}x_3+w_{41}x_4+b_1, \\o_2=w_{12}x_1+w_{22}x_2+w_{32}x_3+w_{42}x_4+b_2, \\o_3=w_{13}x_1+w_{23}x_2+w_{33}x_3+w_{43}x_4+b_3, \\w下标命名規則: \\不同列代表不同輸出類型,不同行代表不同像素點。 \\列數代表真實輸出的類别數;行數代表特征數。 o1​=w11​x1​+w21​x2​+w31​x3​+w41​x4​+b1​,o2​=w12​x1​+w22​x2​+w32​x3​+w42​x4​+b2​,o3​=w13​x1​+w23​x2​+w33​x3​+w43​x4​+b3​,w下标命名規則:不同列代表不同輸出類型,不同行代表不同像素點。列數代表真實輸出的類别數;行數代表特征數。

softmax回歸也是一個單層神經網絡,每個輸出o的計算都要依賴所有的輸入x,是以softmax回歸的輸出層也是一個全連接配接層。

手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

通常将輸出值 oi 作為預測類别 i 的置信度,并将值最大的輸出所對應的類作為預測輸出即

a r g i m a x o i arg_imaxo_i argi​maxoi​

例如o1,o2,o3分别為0.1,10,0.1由于o2最大,那麼預測類别為2。

但這種方法有兩個問題

  1. 輸出層的輸出值的範圍不确定,難以隻管判斷這些值的意義

    如:三個值為0.1,10,0.1時,10代表很置信;但當三個值為1000,10,1000時,10又代表不置信。

  2. 由于真實标簽也是離散值,這些離散值于不确定範圍的輸出值之間的誤差難以衡量。

softmax運算符解決了以上兩個問題。它通過下式将輸出值轉化為值為正且和為1的機率分布。

y 1 ^ , y 2 ^ , y 3 ^ = s o f t m a x ( o 1 , o 2 , o 3 ) \hat{y_1},\hat{y_2},\hat{y_3}=softmax(o_1,o_2,o_3) y1​^​,y2​^​,y3​^​=softmax(o1​,o2​,o3​)

其中

y 1 ^ = e x p ( 0 1 ) ∑ i = 1 3 e x p ( x i ) ,    y 2 ^ = e x p ( 0 2 ) ∑ i = 1 3 e x p ( x i ) ,    y 3 ^ = e x p ( 0 3 ) ∑ i = 1 3 e x p ( x i ) \hat{y_1}=\frac{exp(0_1)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_2}=\frac{exp(0_2)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_3}=\frac{exp(0_3)}{\sum_{i=1}^3exp(xi)} y1​^​=∑i=13​exp(xi)exp(01​)​,  y2​^​=∑i=13​exp(xi)exp(02​)​,  y3​^​=∑i=13​exp(xi)exp(03​)​

非常容易看出

y 1 ^ + y 2 ^ + y 3 ^ = 1 且 0 ≤ y 1 ^ , y 2 ^ , y 3 ^ ≤ 1 \hat{y_1}+\hat{y_2}+\hat{y_3}=1 \\且0\leq\hat{y_1},\hat{y_2},\hat{y_3}\leq1 y1​^​+y2​^​+y3​^​=1且0≤y1​^​,y2​^​,y3​^​≤1

基于上兩式可知,y1,y2,y3是合法的機率分布。例如:y2=0.8那麼不管y1,y3是多少,我們都知道為第二個類别的機率為80%

由于

a r g i m a x o i = a r g i m a x y i ^ arg_imaxo_i = arg_imax\hat{y_i} argi​maxoi​=argi​maxyi​^​

可以知道,softmax運算不改變預測類别輸出。

1.3單樣本分類的矢量計算表達式

為了提高運算效率,采用矢量計算。以上面的圖像分類問題為例權重和偏差參數的矢量表達式為

W = { w 11   w 12   w 13 w 21   w 22   w 23 w 31   w 32   w 33 w 41   w 42   w 43 } ,    b = [ b 1   b 2   b 3 ] W = \left\{ \begin{matrix} w_{11}\ w_{12} \ w_{13} \\w_{21}\ w_{22} \ w_{23} \\w_{31}\ w_{32} \ w_{33} \\w_{41}\ w_{42} \ w_{43} \end{matrix} \right\} ,\ \ b=[b_1 \ b_2\ b_3] W=⎩⎪⎪⎨⎪⎪⎧​w11​ w12​ w13​w21​ w22​ w23​w31​ w32​ w33​w41​ w42​ w43​​⎭⎪⎪⎬⎪⎪⎫​,  b=[b1​ b2​ b3​]

設高和寬分别為2個像素的圖像樣本 i 的特征為

x ( i ) = [ x 1 ( i )   x 2 ( i )   x 3 ( i )   x 4 ( i ) ] x^{(i)}=[x^{(i)}_1 \ x^{(i)}_2 \ x^{(i)}_3 \ x^{(i)}_4] x(i)=[x1(i)​ x2(i)​ x3(i)​ x4(i)​]

輸出層輸出為

o i = [ o 1 i   o 2 i   o 3 i ] o^{i} = [o_1^{i} \ o_2^{i} \ o_3^{i}] oi=[o1i​ o2i​ o3i​]

預測的機率分布為

y ^ ( i ) = [ y ^ 1 ( i )   y ^ 2 ( i )   y ^ 3 ( i ) ] \hat{y}^{(i)}=[\hat{y}^{(i)}_1 \ \hat{y}^{(i)}_2 \ \hat{y}^{(i)}_3] y^​(i)=[y^​1(i)​ y^​2(i)​ y^​3(i)​]

最終得到softmax回歸對樣本 i 分類的矢量計算表達式為

o ( i ) = x ( i ) W + b y ^ ( i ) = s o f t m a x ( o ( i ) ) o^{(i)}=x^{(i)}W+b \\ \hat{y}^{(i)}=softmax(o^{(i)}) o(i)=x(i)W+by^​(i)=softmax(o(i))

對于給定的小批量樣本,存在

O = X W + b Y ^ = s o f t m a x ( O ) O = XW+b \\\hat{Y}=softmax(O) O=XW+bY^=softmax(O)

1.4交叉熵損失函數

使用softmax運算後可以更友善地于離散标簽計算誤差。真實标簽同樣可以變換為一個合法的機率分布,即:對于一個樣本(一個圖像),它的真實類别為y_i,我們就令y_i為1,其餘為0。如圖像為貓(第二個),則它的y = [0 1 0 ]。這樣就可以使\hat{y}更接近y。

在圖像分類問題中,想要預測結果正确并不需要讓預測機率與标簽機率相等(不同動作 顔色的貓),我們隻需要讓真實類别對應的機率大于其他類别的機率即可,是以不必使用線性回歸模型中的平方損失函數。

我們使用交叉熵函數來計算損失。

H ( y ( i ) , y ^ ( i ) ) = − ∑ j = 1 q y j ( i ) l o g   y ^ j ( i ) H(y^{(i)},\hat{y}^{(i)})=-\sum_{j=1}^q y_j^{(i)}log\ \hat{y}^{(i)}_j H(y(i),y^​(i))=−j=1∑q​yj(i)​log y^​j(i)​

這個式子中,y^(i) _j 是真實标簽機率中的為1的那個元素,而 \hat{y}^{(i)}_j 是預測得到的類别機率中與之對應的那個元素。

由于在y(i)中隻有一個标簽,是以在y{i}中,除了y^(i) _j 外,其餘元素都為0,于是得到上式的簡化方程

H ( y ( i ) , y ^ ( i ) ) = − l o g   y ^ j ( i ) H(y^{(i)},\hat{y}^{(i)}) =- log\ \hat{y}^{(i)}_j H(y(i),y^​(i))=−log y^​j(i)​

也就是說交叉熵函數隻與預測到的機率數有關,隻要預測得到的值夠大,就可以確定分類結果的正确性。

對于整體樣本而言,交叉熵損失函數定義為

l ( θ ) = 1 n ∑ i = 1 n H ( y ( i ) , y ^ ( i ) ) l(\theta) =\frac{1}{n} \sum_{i=1}^n H(y^{(i)},\hat{y}^{(i)}) l(θ)=n1​i=1∑n​H(y(i),y^​(i))

其中\theta代表模型參數,如果每個樣本都隻有一個标簽,則上式可以簡化為

l ( θ ) = − 1 n ∑ i = 1 n l o g   y ^ j ( i ) l(\theta) =-\frac{1}{n} \sum_{i=1}^nlog\ \hat{y}^{(i)}_j l(θ)=−n1​i=1∑n​log y^​j(i)​

最小化交叉熵損失函數等價于最大化訓練資料集所有标簽類别的聯合預測機率 。

2.圖像分類資料集(Fashion-MNIST)

這一章節需要用到torchvision包,為此,我重裝了

這個資料集是我們在後面學習中将會用到的圖形分類資料集。它的圖像内容相較于手寫數字識别資料集MINIST更為複雜一些,更加便于我們直覺的觀察算法之間的差異。

這一節主要使用torchvision包,主要用來建構計算機視覺模型。

torchvision包的主要構成 功能
torchvision.datasets 一些加載資料的函數及常用資料集接口
torchvision.madels 包含常用的模型結構(含預訓練模型)
torchvision.transforms 常用的圖檔變換(裁剪、旋轉)
torchvision.utils 其他方法

2.1擷取資料集

首先導入需要的包

import torch 
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")	
#調用庫時,sys.path會自動搜尋路徑,為了導入d2l這個庫,是以需要添加".."
#import d2lzh_pytorch as d2l	這個庫找不到不用了
from IPython import display
#在這一節d2l庫僅僅在繪圖時被使用,是以使用這個庫做替代
           

**通過調用torchvision中的torchvision.datasets來下載下傳這個資料集。**第一次調用從網上自動擷取資料。

通過設定參數train來制定擷取訓練資料集或測試資料集(測試集:用來評估模型表現,并不用來訓練模型)。

通過設定參數transfrom = transforms.ToTensor()将所有資料轉換成Tensor,如果不進行轉換則傳回PIL圖檔。

transforms.ToTensor()函數将尺寸為(H*W*C)且資料位于[0,255]之間的PIL圖檔或者資料類型為np.uint8的NumPy數組轉換為尺寸為(C*H*W)且資料類型為torch.float32且位于[0,0,1.0]的Tensor

C代表通道數,灰階圖像的通道數為1

PIL圖檔是python處理圖檔的标準

注意:transforms.ToTensor()函數預設将輸入類型設定為uint8

#擷取訓練集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
#擷取測試集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
           

其中mnist_train和mnist_test可以用len()來擷取該資料集的大小,還可以用下标來擷取具體的一個樣本。

訓練集和測試集都有10個類别,訓練集中每個類别的圖像數為6000,測試集中每個類别的圖像數為1000,即:訓練集中有60000個樣本,測試集中有10000個樣本。

len(mnist_train)	#輸出訓練集的樣本數
mnist_train[0]		#通過下标通路任意一個樣本,傳回值為兩個torch,一個特征tensor和一個标簽tensor
           

Fashion-MNIST資料集中共有十個類别,分别為: t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴) 。

需要将這些文本标簽和數值标簽互相轉換,可以通過以下函數進行。

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
	#labels是一個清單
	#數值标簽轉文本标簽
           

下面是一個可以在意行裡畫出多張圖像和對應标簽的函數

def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
   	#繪制矢量圖
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    #建立子圖,一行len(images)列,圖檔大小12*12
    for f, img, lbl in zip(figs, images, labels):
        #zip函數将他們壓縮成由多個元組組成的清單
        f.imshow(img.view((28, 28)).numpy())
        #将img轉形為28*28大小的張量,然後轉換成numpy數組
        f.set_title(lbl)
        #設定每個子圖的标題為标簽
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
        #關閉x軸y軸
    plt.show()
           

上述函數的使用

X,y = [],[]
#初始化兩個清單
for i in range(10):
	X.append(mnist_train[i][0])
	#循環向X清單添加圖像
	y.append(mnist_train[i][1])
	#循環向y清單添加标簽
show_fashion_mnist(X,get_fashion_mnist_labels(y))
#顯示圖像和清單
           

2.2讀取小批量

有了線性回歸中讀取小批量的經驗,我們知道讀取小批量可以使用torch中内置的dataloader函數來實作。

dataloader還支援多線程讀取資料,通過設定它的num_workers參數。

batch_size = 256
#小批量數目
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
#num_workers=0,不開啟多線程讀取。
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
           

3. 使用pytorch實作softmax回歸模型

使用pytorch可以更加便利的實作softmax回歸模型。

3.1 擷取和讀取資料

讀取小批量資料的方法:

  1. 首先是擷取資料,pytorch可以通過以下代碼很友善的擷取Fashion-MNIST資料集。
    mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
    
    mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
    
    #參數
    
    #root : processed/training.pt 和 processed/test.pt 的主目錄 
    #train : True = 訓練集, False = 測試集
    #download : True = 從網際網路上下載下傳資料集,并把資料集放在root目錄下. 如果資料集之前下載下傳過,将處理過的資料(minist.py中有相關函數)放在processed檔案夾下
    #transform = transforms.ToTensor():使所有資料轉換為Tensor
               
  2. 然後是生成一個疊代器,用來讀取資料
    #生成疊代器
    train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
    
    test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
    #參數
    
    #dataset:Dataset類型,從其中加載資料
    #batch_size:int類型,每個批量加載多少個數
    #shuffle:bool類型,每個學習周期都打亂順序
    #num_workers:int類型,加載資料時使用多少子程序。預設值為0.
    #collate_fn:定義如何取樣本,可通過定義自己的函數來實作。
    #pin_memory:鎖頁記憶體處理。
    #drop_last:bool類型,如果有剩餘的樣本,True表示丢棄;Flase表示不丢棄
               

3.2 定義和初始化模型

由softmax回歸模型的定義可知,softmax回歸模型隻有權重參數和偏差參數。是以可以使用神經網絡子子產品中的線性子產品。

o 1 = w 11 x 1 + w 21 x 2 + w 31 x 3 + w 41 x 4 + b 1 , o 2 = w 12 x 1 + w 22 x 2 + w 32 x 3 + w 42 x 4 + b 2 , o 3 = w 13 x 1 + w 23 x 2 + w 33 x 3 + w 43 x 4 + b 3 , o_1=w_{11}x_1+w_{21}x_2+w_{31}x_3+w_{41}x_4+b_1, \\o_2=w_{12}x_1+w_{22}x_2+w_{32}x_3+w_{42}x_4+b_2, \\o_3=w_{13}x_1+w_{23}x_2+w_{33}x_3+w_{43}x_4+b_3, o1​=w11​x1​+w21​x2​+w31​x3​+w41​x4​+b1​,o2​=w12​x1​+w22​x2​+w32​x3​+w42​x4​+b2​,o3​=w13​x1​+w23​x2​+w33​x3​+w43​x4​+b3​,

  1. 首先定義網絡,softmax回歸是一個兩層的網絡,是以隻需要定義輸入層和輸出層即可。
num_inputs = 784
num_outputs = 10

class LinearNet(nn.Module):
    def __init__(self,num_inputs,num_outputs):
        super(LinearNet,self).__init__()
        self.linear = nn.Linear(num_inputs,num_outputs)
        #定義一個輸入層
        
    #定義向前傳播(在這個兩層網絡中,它也是輸出層)
    def forward(self,x):
        y = self.linear(x.view(x.shape[0],-1))
        #将x換形為y後,再繼續向前傳播
        return y
    
net = LinearNet(num_inputs,num_outputs)
           
  1. 初始化參數

使用torch.nn中的init可以快速的初始化參數。我們令權重參數為均值為0,标準差為0.01的正态分布。偏差為0。

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0) 
           

3.3 softmax運算和交叉熵損失函數

分開定義softmax運算和交叉熵損失函數會造成數值不穩定。是以PyTorch提供了一個具有良好數值穩定性且包括softmax運算和交叉熵計算的函數。

3.4 定義優化算法

依然使用小批量随機梯度下降作為優化算法。定義學習率為0.1。

3.5 計算分類準确率

計算準确率的原理:

我們把預測機率最大的類别作為輸出類别,如果它與真實類别y一緻,說明預測正确。分類準确率就是正确預測數量與總預測數量之比

首先我們需要得到預測的結果。

從一組預測機率(變量y_hat)中找出最大的機率對應的索引(索引即代表了類别)

#argmax(f(x))函數,對f(x)求最大值所對應的點x。我們令f(x)= dim=1,即可實作求所有行上的最大值對應的索引。
A = y_hat.argmax(dim=1)	
#最終輸出結果為一個行數與y_hat相同的列向量
           

然後我們需要将得到的最大機率對應的類别與真實類别(y)比較,判斷預測是否是正确的

B = (y_hat.argmax(dim=1)==y).float()
#由于y_hat.argmax(dim=1)==y得到的是ByteTensor型資料,是以我們通過.float()将其轉換為浮點型Tensor()
           

最後我們需要計算分類準确率

我們知道y_hat的行數就對應着樣本總數,是以,對B求平均值得到的就是分類準确率

上一步最終得到的資料為tensor(x)的形式,為了得到最終的pytorch number,需要對其進行下一步操作

(y_hat.argmax(dim=1)==y).float().mean().item()
#pytorch number的擷取統一通過.item()實作
           

整理一下,得到計算分類準确率函數

def accuracy(y_hat,y):
    return (y_hat.argmax(dim=1).float().mean().item())
           

作為推廣,該函數還可以評價模型net在資料集data_iter上的準确率。

def net_accurary(data_iter,net):
    right_sum,n = 0.0,0
    for X,y in data_iter:
    #從疊代器data_iter中擷取X和y
        right_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        #計算準确判斷的數量
        n +=y.shape[0]
        #通過shape[0]擷取y的零次元(列)的元素數量
    return right_sum/n
           

3.6 訓練模型

num_epochs = 5
#一共進行五個學習周期

def train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,optimizer,net_accurary):
    for epoch in range(num_epochs):
        #損失值、正确數量、總數 初始化。
        train_l_sum,train_right_sum,n= 0.0,0.0,0
        
        for X,y in train_iter:
            y_hat = net(X)
            l = loss(y_hat,y).sum()
            #資料集損失函數的值=每個樣本的損失函數值的和。            
            optimizer.zero_grad()			#對優化函數梯度清零
            l.backward()	#對損失函數求梯度
            optimizer(params,lr,batch_size)
            
            train_l_sum += l.item()
            train_right_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            
        test_acc = net_accurary(test_iter, net)	#測試集的準确率
        print('第%d學習周期, 誤差%.4f, 訓練準确率%.3f, 測試準确率%.3f' % (epoch + 1, train_l_sum / n, train_right_sum / n, test_acc))
        
train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,optimizernet_accurary,net_accurary)
           

訓練效果

手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

3.7 圖像分類

使用訓練好的模型對測試集進行預測

做一個模型的最終目的當然不是訓練了,是以來預測一下試試。

#将樣本的類别數字轉換成文本
def get_Fashion_MNIST_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
    #labels是一個清單,是以有了for循環擷取這個清單對應的文本清單

#顯示圖像
def show_fashion_mnist(images,labels):
    display.set_matplotlib_formats('svg')
    #繪制矢量圖
    _,figs = plt.subplots(1,len(images),figsize=(12,12))
    #設定添加子圖的數量、大小
    for f,img,lbl in zip(figs,images,labels):
        f.imshow(img.view(28,28).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

#從測試集中獲得樣本和标簽
X, y = iter(test_iter).next()

true_labels = get_Fashion_MNIST_labels(y.numpy())
pred_labels = get_Fashion_MNIST_labels(net(X).argmax(dim=1).numpy())

#将真實标簽和預測得到的标簽加入到圖像上
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

show_fashion_mnist(X[0:9], titles[0:9])

           

實作效果

第一行是真實标簽,第二行是識别标簽
手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

寫文章不易,如果覺得有用,麻煩關注我呗~

歡迎各位關注【拇指筆記】,每天更新我的學習筆記~

手把手教你利用PyTorch實作圖像識别1.softmax回歸2.圖像分類資料集(Fashion-MNIST)3. 使用pytorch實作softmax回歸模型

繼續閱讀