天天看点

Pytorch 04: Pytorch中数据加载---Dataset类和DataLoader类

从代码角度学习理解Pytorch学习框架04: Dataset类和DataLoader类了解,方便我们加载和处理数据。

# coding=utf-8
import matplotlib as mpl
mpl.use('tkagg')   # 调试:agg;  运行: tkagg
import matplotlib.pyplot as plt

import os
import pandas as pd
import torch

"""
torch.utils.data.Dataset 是一个表示数据集的抽象类.
你自己的数据集一般应该继承``Dataset``, 并且重写下面的方法:
    1. __len__ 使用``len(dataset)`` 可以返回数据集的大小
    2. __getitem__ 支持索引, 以便于使用 dataset[i] 可以 获取第i个样本(0索引)
"""
from torch.utils.data import Dataset


"""
torch.utils.data中的DataLoader提供为Dataset类对象提供了:
    1.批量读取数据
    2.打乱数据顺序
    3.使用multiprocessing并行加载数据
    
    DataLoader中的一个参数collate_fn:可以使用它来指定如何精确地读取一批样本,
     merges a list of samples to form a mini-batch.
    然而,默认情况下collate_fn在大部分情况下都表现很好
"""
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from skimage import io, transform
import numpy as np


def just_see_face_dataset():
    """
    Take a quick look at the face-landmarks dataset: read the CSV,
    pull one sample's image name and landmark coordinates, print them,
    and display the annotated image.
    :return: None
    """
    landmarks_frame = pd.read_csv('./faces/face_landmarks.csv')
    n = 65  # arbitrary sample index to inspect
    img_name = landmarks_frame.iloc[n, 0]
    # Series.as_matrix() was removed in pandas 1.0; to_numpy() is the
    # supported replacement.
    landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
    landmarks = landmarks.astype('float').reshape(-1, 2)  # (num_points, 2) (x, y) pairs
    print('img_name: {}'.format(img_name))
    print('landmarks shape: {}'.format(landmarks.shape))
    print('first 4 landmarks: {}'.format(landmarks[:4]))

    plt.figure()
    show_landmarks(io.imread(os.path.join('faces', img_name)), landmarks)
    plt.show()


def show_landmarks(image, landmarks):
    """
    Render an image with its landmark points overlaid in red.
    :param image: image array accepted by ``plt.imshow``
    :param landmarks: (N, 2) array of (x, y) coordinates
    :return: None
    """
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    plt.pause(0.001)  # brief pause so the figure refreshes


class FaceLandmarksDataset(Dataset):
    """Dataset of face images with landmark annotations.

    Each sample is a dict ``{'image': ndarray, 'landmarks': ndarray}``
    where ``landmarks`` has shape (num_points, 2).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        :param csv_file: path to a CSV whose first column is the image
            file name and whose remaining columns are landmark coordinates
        :param root_dir: directory containing the image files
        :param transform: optional callable applied to each sample
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """
        Required override when subclassing Dataset.
        Returns the dataset size (number of CSV rows).
        :return: int
        """
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """
        Required override when subclassing Dataset.
        Loads and returns the idx-th sample.
        :param idx: integer row index into the CSV
        :return: dict with keys 'image' and 'landmarks'
        """
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Series.as_matrix() was removed in pandas 1.0; to_numpy() is
        # the supported replacement.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


def t_dataset():
    """
    Smoke-test FaceLandmarksDataset: print the shapes of the first four
    samples and show them in a 1x4 figure grid.
    :return: None
    """
    # Instantiate the dataset over the faces CSV/image directory.
    face_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces')
    fig = plt.figure()

    for idx in range(len(face_dataset)):
        # Dataset objects support direct indexing via __getitem__.
        sample = face_dataset[idx]
        print(idx, sample['image'].shape, sample['landmarks'].shape)

        axes = plt.subplot(1, 4, idx + 1)
        plt.tight_layout()
        axes.set_title('sample #{}'.format(idx))
        axes.axis('off')
        show_landmarks(sample['image'], sample['landmarks'])
        if idx == 3:
            plt.show()
            break


"""Transform操作"""
class Rescale(object):
    """按照给定尺寸更改一个图像的尺寸
    
    Args:
        output_size (tuple or int): 要求输出的尺寸.  如果是个元组类型, 输出
        和output_size匹配. 如果时int类型,图片的短边和output_size匹配, 图片的
        长宽比保持不变.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # 对于标记点, h和w需要交换位置, 因为对于图像, x和y分别时第1维和第0维
        landmarks = landmarks * [new_w / w, new_h / h]

        # 返回值实际上也是一个sample
        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Randomly crop the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. An int produces
        a square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Crop a random (new_h, new_w) window and shift landmarks with it.

        :param sample: dict with 'image' (H x W x C ndarray) and
            'landmarks' ((N, 2) ndarray of (x, y) coordinates)
        :return: dict with the cropped image and shifted landmarks
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # '+ 1' makes the upper bound inclusive; without it,
        # np.random.randint(0, 0) raises ValueError whenever the image
        # dimension exactly equals the crop size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Landmarks are (x, y), so subtract (left, top).
        landmarks = landmarks - [left, top]

        # The return value is itself a sample dict.
        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert the ndarrays of a sample into torch Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # numpy images are H x W x C; torch expects C x H x W, so
        # move the channel axis to the front.
        chw_image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(chw_image),
                'landmarks': torch.from_numpy(landmarks)}


def use_transoform(one_sample):
    """
    Demonstrate chaining several transforms with transforms.Compose
    and visualizing the result.
    :param one_sample: dict with 'image' and 'landmarks' keys
    :return: None
    """
    # Compose merely applies each transform to the sample in order.
    pipeline = transforms.Compose([Rescale(256), RandomCrop(224)])
    result = pipeline(one_sample)
    plt.figure()
    show_landmarks(result['image'], result['landmarks'])
    plt.show()


def union_all_knowledge():
    """
    Iterate over the whole transformed dataset:
        every access (1) reads the image from file and (2) applies the
        transform chain, which is how the data gets augmented.
    :return: None
    """
    pipeline = transforms.Compose([Rescale(256),
                                   RandomCrop(225),
                                   ToTensor()])
    transformed_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv',
                                               root_dir='./faces',
                                               transform=pipeline)
    for idx in range(len(transformed_dataset)):
        sample = transformed_dataset[idx]
        print(idx, sample['image'].size(), sample['landmarks'].size())
        if idx == 3:
            break


def t_dataloader():
    """
    Wrap the transformed dataset in a DataLoader and iterate it,
    printing the tensor sizes of every batch.
    :return: None
    """
    pipeline = transforms.Compose([Rescale(256),
                                   RandomCrop(225),
                                   ToTensor()])
    transformed_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv',
                                               root_dir='./faces',
                                               transform=pipeline)
    # 4 samples per batch, shuffled, loaded by 2 worker processes.
    dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=2)

    # Iterate the DataLoader to read the data batch by batch.
    for i_batch, sample_batched in enumerate(dataloader):
        image_batch = sample_batched['image']
        landmarks_batch = sample_batched['landmarks']
        print('i_batch: {}, image_batch.size(): {}, landmarks_batch.size(): {}'.format(
            i_batch, image_batch.size(), landmarks_batch.size()))


if __name__ == '__main__':
    # Entry point: uncomment exactly one of the demo functions below to run it.
    # just_see_face_dataset()

    # t_dataset()

    # face_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces')
    # one_sample = face_dataset[0]
    # use_transoform(one_sample)

    # union_all_knowledge()

    t_dataloader()





           

继续阅读