天天看點

基于PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader1. 需要用到的庫2. 資料擴充定義3. 自定義Dataset4. 測試

目錄

1. 需要用到的庫

2. 資料擴充定義

3. 自定義Dataset

4. 測試

         開始一個新的系列,基于Kaggle比賽的貓狗大戰資料集,基于PyTorch實作貓狗圖像分類。

         如何定義網絡模型見:https://blog.csdn.net/linghu8812/article/details/119147899

         資料集位址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview。

         下面是第一部分,主要介紹如何使用Pytorch自定義Dataloader。

1. 需要用到的庫

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
           

2. 資料擴充定義

image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
           

資料擴充主要分為以下幾步:

1)将圖像的短邊resize到256;

2)然後随即裁減224x224;

3)再進行随機水準翻轉;

4)最後将圖像轉為Tensor并且标準化。

3. 自定義Dataset

class DogVsCatDataset(Dataset):
    """Dog vs Cat dataset."""

    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
        else:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label
           

資料集初始化時要設定圖檔目錄;是否是訓練集或者是驗證集,圖檔編号小于10000的為訓練集,大于等于10000的為驗證集;及資料擴充方式;貓的标簽為0,狗的标簽為1。

4. 測試

if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
                                        transform=image_transform)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
    image, label = iter(train_loader).next()
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))
           

測試的時候使用“if __name__ == '__main__':”可以在其他檔案import時,不執行這些語句。執行代碼後,顯示的圖檔和列印的标簽如下所示:

基于PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader1. 需要用到的庫2. 資料擴充定義3. 自定義Dataset4. 測試

Label is: [0]

基于PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader1. 需要用到的庫2. 資料擴充定義3. 自定義Dataset4. 測試

Label is: [1]

繼續閱讀