目錄
1. 需要用到的庫
2. 資料擴充定義
3. 自定義Dataset
4. 測試
開始一個新的系列,基于Kaggle比賽的貓狗大戰資料集,基于PyTorch實作貓狗圖像分類。
如何定義網絡模型見:https://blog.csdn.net/linghu8812/article/details/119147899
資料集位址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview。
下面是第一部分,主要介紹如何使用Pytorch自定義Dataloader。
1. 需要用到的庫
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
2. 資料擴充定義
image_transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
資料擴充主要分為以下幾步:
1)将圖像的短邊resize到256;
2)然後随即裁減224x224;
3)再進行随機水準翻轉;
4)最後将圖像轉為Tensor并且标準化。
3. 自定義Dataset
class DogVsCatDataset(Dataset):
"""Dog vs Cat dataset."""
def __init__(self, root_dir, train=True, transform=None):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.root_dir = root_dir
self.img_path = os.listdir(self.root_dir)
if train:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
else:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
self.transform = transform
def __len__(self):
return len(self.img_path)
def __getitem__(self, idx):
image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
if self.transform:
image = self.transform(image)
label = torch.from_numpy(np.array([label]))
return image, label
資料集初始化時要設定圖檔目錄;是否是訓練集或者是驗證集,圖檔編号小于10000的為訓練集,大于等于10000的為驗證集;及資料擴充方式;貓的标簽為0,狗的标簽為1。
4. 測試
if __name__ == '__main__':
catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
transform=image_transform)
train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
測試的時候使用“if __name__ == '__main__':”可以在其他檔案import時,不執行這些語句。執行代碼後,顯示的圖檔和列印的标簽如下所示:

Label is: [0]
Label is: [1]