天天看點

PyTorch—ImageFolder/自定義類 讀取圖檔資料

文章目錄

  • ​​一、torchvision 圖像資料讀取 [0, 1]​​
  • ​​二、torchvision 的 Transform​​
  • ​​三、讀取圖像資料類​​
  • ​​3.1 class torchvision.datasets.ImageFolder 預設讀取圖像資料方法:​​
  • ​​3.2 自定義資料讀取方法​​

運作環境安裝 Anaconda | python ==3.6.6

conda install pytorch -c pytorch
pip install config
pip install tqdm             #包裝疊代器,顯示進度條
pip install torchvision
pip install scikit-image      

一、torchvision 圖像資料讀取 [0, 1]

​import torchvision.transforms as transforms​

​​ transforms 子產品提供了一般的圖像轉換操作類。

​​

​class torchvision.transforms.ToTensor​

​​

功能:

把shape=(H x W x C) 的像素值為 [0, 255] 的 PIL.Image 和 numpy.ndarray

轉換成shape=(C x H x W)的像素值範圍為​

​[0.0, 1.0]​

​的 torch.FloatTensor。

​class torchvision.transforms.Normalize(mean, std)​

​​

功能:

此轉換類作用于torch.*Tensor。給定均值(R, G, B)和标準差(R, G, B),用公式channel = (channel - mean) / std進行規範化。

import torchvision 
import torchvision.transforms as transforms 
import cv2 
import numpy as np 
from PIL import Image 

img_path = "./data/timg.jpg" 

# 引入transforms.ToTensor()功能: range [0, 255] -> [0.0,1.0] 
transform1 = transforms.Compose([transforms.ToTensor()])

# 直接讀取:numpy.ndarray 
img = cv2.imread(img_path)
print("img = ", img[0])      #隻輸出其中一個通道
print("img.shape = ", img.shape)

# 歸一化,轉化為numpy.ndarray并顯示
img1 = transform1(img) 
img2 = img1.numpy()*255 
img2 = img2.astype('uint8') 
img2 = np.transpose(img2 , (1,2,0)) 
 
print("img1 = ", img1)
cv2.imshow('img2 ', img2 ) 
cv2.waitKey() 


# PIL 讀取圖像
img = Image.open(img_path).convert('RGB') # 讀取圖像 
img2 = transform1(img) # 歸一化到 [0.0,1.0] 
print("img2 = ",img2) #轉化為PILImage并顯示 
img_2 = transforms.ToPILImage()(img2).convert('RGB') 
print("img_2 = ",img_2) 
img_2.show()


從上到下依次輸出:---------------------------------------------
img =   [[197 203 202]
   [195 203 202]
   ...
   [200 208 207]
   [200 208 207]]
img.shape =  (362, 434, 3)

img1 =  tensor([[[0.7725, 0.7647, 0.7686,  ..., 0.7804, 0.7843, 0.7843],
         [0.7765, 0.7725, 0.7686,  ..., 0.7686, 0.7608, 0.7569],
         [0.7843, 0.7725, 0.7686,  ..., 0.7725, 0.7686, 0.7569],
         ...,

img_transform =  tensor([[[0.7922, 0.7922, 0.7961,  ..., 0.8078, 0.8118, 0.8118],
         [0.7961, 0.8000, 0.7961,  ..., 0.7922, 0.7882, 0.7843],
         [0.8039, 0.8000, 0.7961,  ..., 0.8118, 0.8039, 0.7922],
         ...,      
PyTorch—ImageFolder/自定義類 讀取圖檔資料

transforms.Compose 歸一化到 [-1.0, 1.0 ]

transform2 = transforms.Compose([transforms.ToTensor()])
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])      

二、torchvision 的 Transform

在深度學習時關于圖像的資料讀取:由于Tensorflow不支援與numpy的無縫切換,導緻難以使用現成的pandas等格式化資料讀取工具,造成了很多不必要的麻煩,而pytorch解決了這個問題。

pytorch自定義讀取資料和進行Transform的部分請見文檔:

​​​http://pytorch.org/tutorials/beginner/data_loading_tutorial.html​​

但是按照文檔中所描述所完成的自定義Dataset隻能夠使用自定義的Transform步驟,而torchvision包中已經給我們提供了很多圖像transform步驟的實作,為了使用這些已經實作的Transform步驟,我們可以使用如下方法定義Dataset:

from __future__ import print_function, division 
import os 
import torch 
import pandas as pd 
from PIL import Image 
import numpy as np 
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms 

class FaceLandmarkDataset(Dataset): 
    def __len__(self) -> int: 
        return len(self.landmarks_frame)
        
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None: 
        super().__init__() 
        self.landmarks_frame = pd.read_csv(csv_file) 
        self.root_dir = root_dir 
        self.transform = transform 

    def __getitem__(self, index:int): 
        img_name = self.landmarks_frame.ix[index, 0] 
        img_path = os.path.join('./faces', img_name) 
        with Image.open(img_path) as img: 
            image = img.convert('RGB') 
        landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float') 
        landmarks = np.reshape(landmarks,newshape=(-1,2)) 
        if self.transform is not None: 
            image = self.transform(image) 
        return image, landmarks 

########################以上為資料讀取類(傳回:image,landmarks)###############################
trans = transforms.Compose(transforms = [transforms.RandomSizedCrop(size=128), 
                                         transforms.ToTensor()]) 

face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv', 
                   root_dir='faces', transform= trans) 
loader = DataLoader(dataset = face_dataset, 
                    batch_size=4,
            shuffle=True,
            num_workers=4)      

三、讀取圖像資料類

3.1 class torchvision.datasets.ImageFolder 預設讀取圖像資料方法:
  • ​__init__​

    ​( 初始化)
  • ​classes, class_to_idx = find_classes(root)​

    ​​ :得到分類的類别名(classes)和類别名與數字類别的映射關系字典(class_to_idx)

    其中 classes (list): List of the class names.

    其中 class_to_idx (dict): Dict with items (class_name, class_index).

  • ​imgs = make_dataset(root, class_to_idx)​

    ​​得到imags清單。

    其中 imgs (list): List of (image path, class_index) tuples

    每個值是一個tuple,每個tuple包含兩個元素:圖像路徑和标簽

  • ​__getitem__​

    ​(圖像擷取)
  • ​path, target = self.imgs[index]​

    ​ 擷取圖像(路徑,标簽)
  • ​img = self.loader(path)​

    ​資料讀取。
  • ​img = self.transform(img)​

    ​資料、标簽 轉換成 tensor
  • ​target = self.target_transform(target)​

  • ​__len__​

    ​( 資料集數量)
  • ​return len(self.imgs)​

class ImageFolder(data.Dataset):
    """預設圖像資料目錄結構
    root
    .
    ├──dog
    |   ├──001.png
    |   ├──002.png
    |   └──...
    └──cat  
    |   ├──001.png
    |   ├──002.png
    |   └──...
    └──...
    """
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        index (int): Index
    Returns:tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)      

圖像擷取 ​

​__getitem__​

def pil_loader(path):    # 一般采用pil_loader函數。
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)      
3.2 自定義資料讀取方法
  • ​__init__()​

    ​初始化傳入參數:
  • img_path 裡面為所有圖像資料(包括訓練和測試)

    txt_path 裡面有 train.txt和val.txt兩個檔案:txt檔案中每行都是圖像路徑,tab鍵,标簽。

  • 其中 self.img_name 和 self.img_label 的讀取方式就跟你資料的存放方式有關(需要調整的地方)
  • ​__getitem__()​

    ​依然采用default_loader方法來讀取圖像。
  • ​Transform​

    ​中将每張圖像都封裝成 Tensor
class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '',data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            """
        關于json檔案解析:
        https://blog.csdn.net/wsp_1138886114/article/details/83302339
        txt檔案解析如下,具體文本解析具體分析,沒有定數
            """
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label
#####################以上為圖像資料讀取,傳回(img, label)#########################

# 保證image_datasets與torchvision.datasets.ImageFolder類傳回的資料類型一樣
image_datasets = {x: customData(img_path='/ImagePath',
                                txt_path=('/TxtFile/' + x + '.txt'),
                                data_transforms=data_transforms,
                                dataset=x) for x in ['train', 'val']}

#用torch.utils.data.DataLoader類,将這個batch的圖像資料和标簽都分别封裝成Tensor。
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=batch_size,
                                             shuffle=True) for x in ['train', 'val']}

# 模型儲存
torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))