天天看點

Pytorch 04: Pytorch中資料加載---Dataset類和DataLoader類

從代碼角度學習了解Pytorch學習架構04: Dataset類和DataLoader類了解,友善我們加載和處理資料。

# coding=utf-8
import matplotlib as mpl
mpl.use('tkagg')   # 調試:agg;  運作: tkagg
import matplotlib.pyplot as plt

import os
import pandas as pd
import torch

"""
torch.utils.data.Dataset 是一個表示資料集的抽象類.
你自己的資料集一般應該繼承``Dataset``, 并且重寫下面的方法:
    1. __len__ 使用``len(dataset)`` 可以傳回資料集的大小
    2. __getitem__ 支援索引, 以便于使用 dataset[i] 可以 擷取第i個樣本(0索引)
"""
from torch.utils.data import Dataset


"""
torch.utils.data中的DataLoader提供為Dataset類對象提供了:
    1.批量讀取資料
    2.打亂資料順序
    3.使用multiprocessing并行加載資料
    
    DataLoader中的一個參數collate_fn:可以使用它來指定如何精确地讀取一批樣本,
     merges a list of samples to form a mini-batch.
    然而,預設情況下collate_fn在大部分情況下都表現很好
"""
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from skimage import io, transform
import numpy as np


def just_see_face_dataset():
    """Peek at the face-landmarks dataset: load one row and visualize it.

    Reads ./faces/face_landmarks.csv, extracts the image file name and the
    (x, y) landmark coordinates of row 65, prints a short summary, and shows
    the annotated image.
    """
    landmarks_frame = pd.read_csv('./faces/face_landmarks.csv')
    n = 65
    img_name = landmarks_frame.iloc[n, 0]
    # Series.as_matrix() was removed in pandas 1.0; .to_numpy() is the
    # supported replacement and returns the same ndarray.
    landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
    landmarks = landmarks.astype('float').reshape(-1, 2)
    print('img_name: {}'.format(img_name))
    print('landmarks shape: {}'.format(landmarks.shape))
    print('first 4 landmarks: {}'.format(landmarks[:4]))

    plt.figure()
    show_landmarks(io.imread(os.path.join('faces', img_name)), landmarks)
    plt.show()


def show_landmarks(image, landmarks):
    """Render *image* and overlay its landmark points in red.

    :param image: image array accepted by ``plt.imshow``
    :param landmarks: (N, 2) array of (x, y) point coordinates
    """
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    plt.pause(0.001)  # brief pause so the figure actually refreshes


class FaceLandmarksDataset(Dataset):
    """Face-landmarks dataset: image files plus per-image (x, y) points.

    Args:
        csv_file (str): path to a CSV whose first column is the image file
            name and whose remaining columns are landmark coordinates.
        root_dir (str): directory containing the image files.
        transform (callable, optional): applied to each sample dict before
            it is returned.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the dataset size (number of rows in the CSV).

        Required override when subclassing ``Dataset``.
        """
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Required override when subclassing ``Dataset``; enables ``dataset[idx]``.

        :param idx: 0-based sample index
        :return: ``{'image': ndarray, 'landmarks': (N, 2) float ndarray}``,
            optionally passed through ``self.transform``
        """
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Series.as_matrix() was removed in pandas 1.0; use .to_numpy() instead.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


def t_dataset():
    """Exercise FaceLandmarksDataset: print and plot the first four samples."""
    # Instantiate the dataset over the faces CSV + image directory.
    face_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces')
    fig = plt.figure()

    for idx in range(len(face_dataset)):
        # A Dataset object supports direct indexing via __getitem__.
        sample = face_dataset[idx]
        print(idx, sample['image'].shape, sample['landmarks'].shape)

        axis = plt.subplot(1, 4, idx + 1)
        plt.tight_layout()
        axis.set_title('sample #{}'.format(idx))
        axis.axis('off')
        show_landmarks(sample['image'], sample['landmarks'])
        if idx == 3:
            plt.show()
            break


"""Transform操作"""
class Rescale(object):
    """按照給定尺寸更改一個圖像的尺寸
    
    Args:
        output_size (tuple or int): 要求輸出的尺寸.  如果是個元組類型, 輸出
        和output_size比對. 如果時int類型,圖檔的短邊和output_size比對, 圖檔的
        長寬比保持不變.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # 對于标記點, h和w需要交換位置, 因為對于圖像, x和y分别時第1維和第0維
        landmarks = landmarks * [new_w / w, new_h / h]

        # 傳回值實際上也是一個sample
        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Randomly crop the image in a sample.

    Args:
        output_size (tuple or int): desired output size; an int yields a
            square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # BUG FIX: np.random.randint(low, high) requires low < high, so the
        # original code raised ValueError whenever an image edge exactly
        # equals the crop size (e.g. Rescale(256) followed by RandomCrop(256)).
        # Use offset 0 in that case; otherwise sample as before.
        top = 0 if h == new_h else np.random.randint(0, h - new_h)
        left = 0 if w == new_w else np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Shift landmarks into the cropped frame (x maps to left, y to top).
        landmarks = landmarks - [left, top]

        # The return value is itself a sample dict.
        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert a sample's ndarrays into torch Tensors."""

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']
        # numpy images are H x W x C; torch expects C x H x W, so swap axes.
        chw_image = image.transpose((2, 0, 1))
        return {
            'image': torch.from_numpy(chw_image),
            'landmarks': torch.from_numpy(landmarks),
        }


def use_transoform(one_sample):
    """Demonstrate composing several transforms into one pipeline.

    :param one_sample: sample dict with 'image' and 'landmarks' keys
    :return:
    """
    # transforms.Compose simply applies the transforms in sequence.
    pipeline = transforms.Compose([Rescale(256), RandomCrop(224)])
    result = pipeline(one_sample)
    plt.figure()
    show_landmarks(result['image'], result['landmarks'])
    plt.show()


def union_all_knowledge():
    """Iterate the full transformed dataset.

    Each access (1) reads the image from disk and (2) applies the composed
    transforms, so the dataset is augmented on the fly.
    """
    pipeline = transforms.Compose([
        Rescale(256),
        RandomCrop(225),
        ToTensor()])
    transformed_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces',
                                               transform=pipeline)
    for idx in range(len(transformed_dataset)):
        sample = transformed_dataset[idx]
        print(idx, sample['image'].size(), sample['landmarks'].size())
        if idx == 3:
            break


def t_dataloader():
    """Batch, shuffle, and parallel-load the transformed dataset via DataLoader."""
    pipeline = transforms.Compose([
        Rescale(256),
        RandomCrop(225),
        ToTensor()])
    dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces',
                                   transform=pipeline)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # Iterating the DataLoader yields collated mini-batches.
    for i_batch, batch in enumerate(loader):
        images, points = batch['image'], batch['landmarks']
        print('i_batch: {}, image_batch.size(): {}, landmarks_batch.size(): {}'.format(
            i_batch, images.size(), points.size()))


if __name__ == '__main__':
    # Each commented call below is an independent demo; uncomment one to run it.
    # just_see_face_dataset()

    # t_dataset()

    # face_dataset = FaceLandmarksDataset(csv_file='./faces/face_landmarks.csv', root_dir='./faces')
    # one_sample = face_dataset[0]
    # use_transoform(one_sample)

    # union_all_knowledge()

    t_dataloader()





           

繼續閱讀