天天看点

基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader1. 需要用到的库2. 数据扩充定义3. 自定义Dataset4. 测试

目录

1. 需要用到的库

2. 数据扩充定义

3. 自定义Dataset

4. 测试

         开始一个新的系列,基于Kaggle比赛的猫狗大战数据集,基于PyTorch实现猫狗图像分类。

         如何定义网络模型见:https://blog.csdn.net/linghu8812/article/details/119147899

         数据集地址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview。

         下面是第一部分,主要介绍如何使用Pytorch自定义Dataloader。

1. 需要用到的库

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
           

2. 数据扩充定义

image_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
           

数据扩充主要分为以下几步:

1)将图像的短边resize到256;

2)然后随即裁减224x224;

3)再进行随机水平翻转;

4)最后将图像转为Tensor并且标准化。

3. 自定义Dataset

class DogVsCatDataset(Dataset):
    """Dog vs Cat dataset."""

    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
        else:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label
           

数据集初始化时要设置图片目录;是否是训练集或者是验证集,图片编号小于10000的为训练集,大于等于10000的为验证集;及数据扩充方式;猫的标签为0,狗的标签为1。

4. 测试

if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
                                        transform=image_transform)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
    image, label = iter(train_loader).next()
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))
           

测试的时候使用“if __name__ == '__main__':”可以在其他文件import时,不执行这些语句。执行代码后,显示的图片和打印的标签如下所示:

基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader1. 需要用到的库2. 数据扩充定义3. 自定义Dataset4. 测试

Label is: [0]

基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader1. 需要用到的库2. 数据扩充定义3. 自定义Dataset4. 测试

Label is: [1]

继续阅读