版權聲明:本文為部落客原創文章,未經部落客允許不得轉載。 https://blog.csdn.net/Teeyohuang/article/details/79222857
pytorch進行CIFAR-10分類(5)測試
我的系列博文:
Pytorch打怪路(一)pytorch進行CIFAR-10分類(1)CIFAR-10資料加載和處理
Pytorch打怪路(一)pytorch進行CIFAR-10分類(2)定義卷積神經網絡
Pytorch打怪路(一)pytorch進行CIFAR-10分類(3)定義損失函數和優化器
Pytorch打怪路(一)pytorch進行CIFAR-10分類(4)訓練
Pytorch打怪路(一)pytorch進行CIFAR-10分類(5)測試(本文)
1.直接上代碼
代碼第一部分
dataiter = iter(testloader) # 建立一個python疊代器,讀入的是我們第一步裡面就已經加載好的testloader
images, labels = dataiter.next() # 傳回一個batch_size的圖檔,根據第一步的設定,應該是4張
# print images
imshow(torchvision.utils.make_grid(images)) # 展示這四張圖檔
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) # python字元串格式化 ' '.join表示用空格來連接配接後面的字元串,參考python的join()方法
複制
複制
這一部分代碼就是先随機讀取4張圖檔,讓我們看看這四張圖檔是什麼并列印出相應的label資訊,
因為第一步裡面設定了是shuffle了資料的,也就是順序是打亂的,是以各自出現的圖像不一定相同,
代碼第二部分
outputs = net(Variable(images)) # 注意這裡的images是我們從上面獲得的那四張圖檔,是以首先要轉化成variable
_, predicted = torch.max(outputs.data, 1)
# 這個 _ , predicted是python的一種常用的寫法,表示後面的函數其實會傳回兩個值
# 但是我們對第一個值不感興趣,就寫個_在那裡,把它指派給_就好,我們隻關心第二個值predicted
# 比如 _ ,a = 1,2 這中指派語句在python中是可以通過的,你隻關心後面的等式中的第二個位置的值是多少
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) # python的字元串格式化
複制
這裡用到了torch.max( ), 它是屬于Tensor的一個方法:
注意到注釋中第一句話,是說傳回傳回輸入Tensor中每行的最大值,并轉換成指定的dim(次元),
是以我們程式中的 torch.max(outputs.data, 1) ,傳回一個tuple (元組)
而這裡很明顯,這個傳回的元組的第一個元素是image data,即是最大的 值,第二個元素是label, 即是最大的值 的 索引!
我們隻需要label(最大值的索引),是以就會有 _ , predicted這樣的指派語句,表示忽略第一個傳回值,把它指派給 _, 就是舍棄它的意思;
我在注釋中也說明了這是什麼意思
這裡說一下,這第二個參數1,看清楚上面的說明是 the dimension to reduce! 而不是去這個dimension上面找最大
是以這裡dim=1,基于我們的a是 4行 x 4列 這麼一個次元,是以指的是 消除列這個次元,這是個什麼意思呢?
如果我們把上面的示例代碼中,的參數 keepdim=True寫上,torch.max(a,1,keepdim=True), 會發現,傳回的結果的第一個元素,即表示最大的值的那部分,其實是一個 size為 【4,1】的Tensor,也就是其實它是在 按照每行 來找最大,是以結果是4行,然後因為隻找一個最大值,是以是1列,整個size就是 4行 1 列, 然後參數dim=1,相當于調用了 squeeze(1),這個操作,上面的說明也是這麼寫的,是以最後就得到結果是一個size為4的vector。
你可以自己下去在ipython裡面做實驗,發現如果dim=0,它其實是在傳回每列的最大值,
是以一定不要搞混!這裡的dim是指的 the dimension to reduce!并不是在the dimension上去傳回最大值。
是以其實我自己寫的時候一般更喜歡用 torch.argmax()這個函數更直覺更好了解一些
總之在這裡你隻需要了解這行操作的功能是:傳回了最大的索引,即預測出來的類别。 想深入研究可以自己去ipython裡面試一下
代碼第三部分
correct = 0 # 定義預測正确的圖檔數,初始化為0
total = 0 # 總共參與測試的圖檔數,也初始化為0
for data in testloader: # 循環每一個batch
images, labels = data
outputs = net(Variable(images)) # 輸入網絡進行測試
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0) # 更新測試圖檔的數量
correct += (predicted == labels).sum() # 更新正确分類的圖檔的數量
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total)) # 最後列印結果
複制
tutorial給的結果是53%
代碼第四部分
來測試一下每一類的分類正确率
class_correct = list(0. for i in range(10)) # 定義一個存儲每類中測試正确的個數的 清單,初始化為0
class_total = list(0. for i in range(10)) # 定義一個存儲每類中測試總數的個數的 清單,初始化為0
for data in testloader: # 以一個batch為機關進行循環
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)
c = (predicted == labels).squeeze()
for i in range(4): # 因為每個batch都有4張圖檔,是以還需要一個4的小循環
label = labels[i] # 對各個類的進行各自累加
class_correct[label] += c[i]
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
複制