Getting Started with PyTorch: General Methods for Data Loading and Preprocessing

Reposted from: CSDN

Original: https://blog.csdn.net/Hungryof/article/details/76649006

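This post covers the main uses of torchvision.

Two kinds of datasets:

  1. All images sit in a single folder. (A subclass of torch.utils.data.Dataset is enough for this.)
  2. Images of different classes sit in different folders. (Use torchvision.datasets.ImageFolder('image_dir_root').)

Most tasks probably use the first kind. The second kind is typical of classification tasks: ImageNet, for example, has 1000 classes and therefore 1000 folders.

For the second kind, the directory structure looks like this: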

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png

.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
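As a rough illustration of how a one-folder-per-class layout like the one above becomes integer labels, here is a plain-Python sketch. It is a simplified stand-in, not ImageFolder's actual code: the helper names find_classes and list_samples, the extension list, and the temporary files are all assumptions made for the example. The one behavior it mirrors is that class folders are sorted by name and numbered from 0.

```python
import os
import tempfile


def find_classes(root):
    """Return sorted class names and a name -> index mapping (one class per subfolder)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def list_samples(root, class_to_idx, extensions=(".png", ".jpg", ".jpeg")):
    """Collect (path, class_index) pairs for the image files in each class folder."""
    samples = []
    for name, idx in sorted(class_to_idx.items()):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(class_dir, fname), idx))
    return samples


# Build a throwaway copy of the layout with empty placeholder files.
root = tempfile.mkdtemp()
for rel in ["ants/xxx.png", "ants/xxy.jpeg", "bees/123.jpg", "bees/nsdf3.png"]:
    path = os.path.join(root, *rel.split("/"))
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "wb").close()

classes, class_to_idx = find_classes(root)
samples = list_samples(root, class_to_idx)
print(class_to_idx)                  # {'ants': 0, 'bees': 1}
print([idx for _, idx in samples])   # [0, 0, 1, 1]
```

Each sample is a (path, label) pair, which is the same kind of information a classification dataset needs to return an (image, target) tuple from __getitem__.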
           

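Note:

The torchvision package serves three purposes:

  1. It provides popular models, and it can directly handle commonly used datasets.
  2. It extends torch.utils.data.Dataset, mainly to read data laid out with one folder per class: torchvision.datasets.ImageFolder is a subclass of torch.utils.data.Dataset. Both are indexable datasets (they implement __getitem__ and __len__) rather than iterators in the strict sense; the DataLoader is what iterates over them.
  3. It provides ready-made torchvision.transforms, which saves you from writing them yourself.

Two ways to read data

In general you will use:

  1. torch.utils.data.Dataset (the low-level base class), a custom class that inherits from it, or torchvision.datasets.ImageFolder, which also inherits from it.
  2. torchvision.transforms, to transform the images read in step 1.
  3. torch.utils.data.DataLoader, to read the dataset from step 2 in batches using multiple worker processes.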

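Outline of the reading pipeline

  1. Define a custom dataset class, the lowest layer, by subclassing torch.utils.data.Dataset. In practice you define at least three methods:

    __init__, __getitem__ and __len__.

    This class is responsible for reading images from the data store. The images usually need various transformations such as scaling, so the transform can be passed in through __init__. __getitem__ then loads the image and applies img_transformed = self.transforms(img), where self.transforms is the argument that was passed to __init__.

  2. Pass a torchvision.transforms.Compose object as an argument to the custom dataset class.
  3. Read the dataset from step 2 with torch.utils.data.DataLoader, using multiple worker processes.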
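The three steps above can be sketched end to end with a toy dataset. This is a minimal illustration, not code from the official examples: SquaresDataset and halve are hypothetical names, the samples are synthetic tensors instead of images, and num_workers=0 keeps loading in the main process so the snippet runs anywhere.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Toy dataset: sample i is the tensor [i], and its target is i squared."""

    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform  # any callable, applied to each sample

    def __getitem__(self, index):
        x = torch.tensor([float(index)])
        y = index * index
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return self.n


def halve(x):
    """Stand-in for an image transform."""
    return x / 2


dataset = SquaresDataset(10, transform=halve)
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

for inputs, targets in loader:
    # The loader stacks individual samples into batches of up to 4.
    print(tuple(inputs.shape), targets.tolist())
```

With 10 samples and a batch size of 4, the loader yields three batches of sizes 4, 4 and 2. Setting num_workers above 0 moves loading into worker processes; on platforms that spawn processes, the loop then typically needs to sit under an `if __name__ == "__main__":` guard.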

Using torch.utils.data.Dataset for All Images in One Folder

Take the official super_resolution example:

First, in main:

train_set = get_training_set(opt.upscale_factor)
test_set = get_test_set(opt.upscale_factor)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
           

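Next, look at get_training_set, which leads to data.py. This script mainly downloads and extracts the data, and defines the transforms and the functions that build the datasets: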

from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
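# Scale was renamed Resize in later torchvision releases; the alias keeps the calls below unchanged.
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize as Scale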

from dataset import DatasetFromFolder


def download_bsd300(dest="dataset"):
    output_image_dir = join(dest, "BSDS300/images")

    if not exists(output_image_dir):
        makedirs(dest)
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)

        data = urllib.request.urlopen(url)

        file_path = join(dest, basename(url))
        with open(file_path, 'wb') as f:
            f.write(data.read())

        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                tar.extract(item, dest)

        remove(file_path)

    return output_image_dir


def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Scale(crop_size // upscale_factor),
        ToTensor(),
    ])


def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])
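To make the size arithmetic in these helpers concrete, here is a small worked example. It reuses calculate_valid_crop_size exactly as defined above; the loop and the variable names target_size and input_size are only for illustration. Because CenterCrop produces a square crop, resizing it to crop_size // upscale_factor gives a square low-resolution input.

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Largest size not exceeding crop_size that is divisible by upscale_factor.
    return crop_size - (crop_size % upscale_factor)


for upscale_factor in (2, 3, 4):
    target_size = calculate_valid_crop_size(256, upscale_factor)
    input_size = target_size // upscale_factor  # side length of the low-resolution input
    print(upscale_factor, target_size, input_size)
# 2 256 128
# 3 255 85
# 4 256 64
```

Rounding the crop size down this way guarantees that input_size * upscale_factor equals target_size, so the network output and the target have the same shape.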

           

From this point on, the custom dataset class is called.

def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)
           

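    # The custom dataset class takes the transforms as arguments. Note that it
    # is the Compose object returned by input_transform(...) that gets passed
    # into the custom class, not the function input_transform itself.
    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))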
           
def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))
           

Next, open dataset.py, where the custom dataset class is defined.

import torch.utils.data as data

from os import listdir
from os.path import join
from PIL import Image


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y


class DatasetFromFolder(data.Dataset):

    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform
           

    # __getitem__ loads the image and applies the transformations that were
    # passed in:
    #     input = self.input_transform(input)
    # Here self.input_transform is the instance that was passed in. Because its
    # class defines __call__, the instance is callable and can be invoked as
    # instance(argument). The previous post covered this.
    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
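The reason a transform object can be called like a function is that its class defines __call__. The following plain-Python sketch shows the idea with numbers instead of images. SimpleCompose, AddOffset and MultiplyBy are hypothetical classes written for this example; SimpleCompose only imitates the chaining behavior of a transform pipeline and is not torchvision's implementation.

```python
class SimpleCompose:
    """Minimal stand-in for a transform pipeline: applies callables in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


class AddOffset:
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, x):
        return x + self.offset


class MultiplyBy:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


# Parameters are fixed when the instances are created; data is supplied when they are called.
pipeline = SimpleCompose([AddOffset(1), MultiplyBy(10)])
print(callable(pipeline))  # True
print(pipeline(2))         # (2 + 1) * 10 = 30
```

This split between configuring at construction time and applying at call time is what lets a dataset store a transform in __init__ and apply it to each sample in __getitem__.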
           

A look inside torchvision.datasets.MNIST

class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
           

    # As you can see, this class also applies the transform as
    #     img = self.transform(img)
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
           

Using torchvision.datasets.ImageFolder for One Class of Images per Folder

For example, the ImageNet code:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
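Normalize subtracts the per-channel mean and divides by the per-channel standard deviation. The arithmetic can be checked by hand on a single RGB pixel, assuming the pixel values are already in the range [0, 1] as they are after ToTensor. The helper normalize_pixel is a hypothetical name used only for this sketch.

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean, std):
    """Apply (value - mean) / std to each channel of one RGB pixel."""
    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(mean, mean, std))  # [0.0, 0.0, 0.0]

# A pure white pixel ends up a little over two standard deviations above the mean.
white = normalize_pixel([1.0, 1.0, 1.0], mean, std)
print([round(v, 3) for v in white])
```

The same per-channel formula is applied to every pixel of the tensor, which is why Normalize has to come after ToTensor in the Compose list.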
           

Steps 1 and 2

Read the dataset with ImageFolder

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),  # named RandomSizedCrop in older torchvision releases
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
           

Step 3

Read with DataLoader, using multiple worker processes

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),  # named Scale in older torchvision releases
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
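The expression shuffle=(train_sampler is None) in the training loader is there because DataLoader does not accept a custom sampler together with shuffle=True. The sketch below illustrates this with a toy TensorDataset. The sampler choice is an assumption made for the example: it uses SubsetRandomSampler with indices 0 to 5, whereas the original script obtains train_sampler elsewhere.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Ten toy samples: a one-element feature tensor and an integer label each.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# Hypothetical sampler that draws only the first six samples, in random order.
train_sampler = SubsetRandomSampler(list(range(6)))

loader = DataLoader(dataset, batch_size=4,
                    shuffle=(train_sampler is None),  # False here, since a sampler is given
                    sampler=train_sampler, num_workers=0)

seen = sorted(int(label) for _, labels in loader for label in labels)
print(len(loader), seen)

# Asking for shuffling and a sampler at the same time is rejected.
try:
    DataLoader(dataset, batch_size=4, shuffle=True, sampler=train_sampler)
    conflict_rejected = False
except ValueError as err:
    conflict_rejected = True
    print("rejected:", err)
```

When no sampler is supplied, the same expression evaluates to True and the loader shuffles the whole dataset on its own.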
           
