天天看點

EfficientNet遷移學習(二) —— 資料讀取(DataLoader.py)

模組介紹

資料讀取部分通常包含如下功能:

  1. 建立資料類,方便管理。
  2. 類的功能1,解析資料文本,擷取樣本名字等。
  3. 類的功能2,資料增強。
  4. 類的功能3,資料實時處理,比如讀取圖像,标準化等。
  5. 建立2個隊列,訓練集隊列和驗證集隊列。

代碼架構

第一步:建立資料讀取類(DataLoader),類的初始化函數通常用于解析訓練資料的文本,擷取其中的檔案名。
class DataLoader:
    """Hold the shuffled list of sample records read from an annotation text file."""

    def __init__(self, file, mode):
        """Read every non-blank line of ``file`` into ``self.name_list`` and shuffle it.

        :param file: path of the text file listing the samples, one record per line.
        :param mode: accepted for interface compatibility; not used by the constructor.
        """
        self.input = cfg.train.input_size
        self.root_path = cfg.train.root_path

        # The context manager closes the handle even if reading fails.
        # Blank lines are dropped because they carry no record.
        with open(file, 'r') as data:
            stripped_lines = (line.strip() for line in data)
            self.name_list = [line for line in stripped_lines if line]
        random.shuffle(self.name_list)

        # if mode == 'train':
        #     self.name_list = self.name_list[0:150]
        #     data1 = open('../B7Data/1025_color_new/split_1.txt', 'r')
        #     for line1 in data1:
        #         line1 = line1.strip()
        #         self.name_list.append(line1)
        #     random.shuffle(self.name_list)
        #     print('混合的訓練集數量: ', len(self.name_list)) 
           
第二步:資料增強,圖像的基本處理方法,包含顔色增強系列,随機翻轉,随機旋轉等。
def image_enhance(self, img):
        """Apply one randomly selected PIL enhancement to an image.

        One of colour, brightness, contrast or sharpness is picked with equal
        probability and applied with a factor drawn from its own range.

        :param img: image array accepted by ``Image.fromarray``.
        :return: the enhanced image converted back to a numpy array.
        """
        # randint is inclusive on both ends; the range has to start at 0,
        # otherwise the colour branch (p == 0) can never be selected.
        p = random.randint(0, 3)
        a1 = random.uniform(0.8, 2)
        a2 = random.uniform(0.8, 1.4)
        a3 = random.uniform(0.8, 1.7)
        a4 = random.uniform(0.8, 2.5)
        img = Image.fromarray(img)

        if p == 0:
            img = ImageEnhance.Color(img).enhance(a1)
        elif p == 1:
            img = ImageEnhance.Brightness(img).enhance(a2)
        elif p == 2:
            img = ImageEnhance.Contrast(img).enhance(a3)
        else:
            img = ImageEnhance.Sharpness(img).enhance(a4)

        return np.array(img)

    def flip_img(self, img):
        """Mirror the image along its second axis when a uniform draw is <= 0.5.

        :param img: array indexed as [row, col, channel].
        :return: a reversed view of ``img`` when the draw selects a flip,
            otherwise ``img`` itself.
        """
        if np.random.random() <= 0.5:
            return img[:, ::-1, :]
        return img

    @staticmethod
    def show_image(name, data):
        """Show ``data`` in an OpenCV window titled ``name`` (debugging helper).

        Waits for a key press via ``cv2.waitKey(0)`` and then destroys all
        OpenCV windows.
        """
        cv2.imshow(name, data)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def pose_rotation(self, img):
        """Rotate the image about its centre by a random angle in [-15, 15] degrees.

        :param img: image array whose first two axes are (rows, cols).
        :return: the rotated image returned by ``affine_transform_cv2``.
        """
        # A numpy image is shaped (rows, cols, ...), i.e. (height, width, ...).
        # The centre offset expects x=width and y=height, so unpack in that order.
        h, w = img.shape[:2]
        deg = random.uniform(-15.0, 15.0)
        M_rotate = affine_rotation_matrix(angle=deg)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        return affine_transform_cv2(img, transform_matrix)
           
第三步:資料的實時處理,比如讀取圖像,資料标準化,然後将資料以一個batch塊的方式放入隊列。
def load_data(self, batch_size, queue, thread, name_queue, mode):
        """
        Endlessly read records from the name queue, load the images and emit batches.

        A record is "<relative path> <score>"; the path may contain spaces and the
        score is the text after the last space. A score >= 10 gives label 1,
        anything else label 0. Records whose image is missing or cannot be
        decoded are reported and skipped.

        :param batch_size: number of samples per emitted batch.
        :param queue: output queue; receives [names, worker ids, images, labels].
        :param thread: id of the worker running this loop, stored once per sample.
        :param name_queue: queue supplying the record strings.
        :param mode: 'train' enables the random flip; 'train-l' keeps a sample only
            while its class counter is <= 16; any other value keeps every sample.
        :return: never returns; the loop runs until the process is stopped.
        """
        image = []
        label = []
        data_name = []
        thread_name = []

        sign_0 = 0
        sign_1 = 0

        while True:
            data = name_queue.get()

            # Split on the LAST space so that paths containing spaces stay intact.
            relative_path, _, score_text = data.rpartition(' ')
            data_image = self.root_path + relative_path
            score = float(score_text)

            if not os.path.exists(data_image):
                print('資料不存在:', data_image)
                continue

            img = cv2.imread(data_image)
            if img is None:
                # imread signals an unreadable file by returning None.
                print('cannot decode image:', data_image)
                continue

            if score >= 10:
                data_label = 1
                sign_1 = sign_1 + 1
            else:
                data_label = 0
                sign_0 = sign_0 + 1

            # Augmentation: only the random flip is active; image_enhance and
            # pose_rotation are available on the class but currently disabled.
            if mode == 'train':
                img = self.flip_img(img)

            # Dataset statistics (BGR) computed on
            # "../B7Data/1025_color_new/train_1025_color.txt", kept for reference:
            #   mean = [33.65559903, 120.61937841, 116.81338165]
            #   std  = [6.550119402970968, 6.312448303275082, 8.977662213055952]
            # They are not applied; pixels are only scaled to [0, 1].
            img = img.astype(np.float32) / 255.0

            if mode == 'train-l':
                # NOTE(review): the counters are never reset and thread_name is not
                # filled in this mode - confirm the intended behaviour.
                if (sign_0 <= 16) and (score < 10):
                    data_name.append(data_image)
                    image.append(img)
                    label.append(data_label)

                if (sign_1 <= 16) and (score >= 10):
                    data_name.append(data_image)
                    image.append(img)
                    label.append(data_label)
            else:
                data_name.append(data_image)
                image.append(img)
                thread_name.append(thread)
                label.append(data_label)

            if len(image) != batch_size:
                continue

            queue.put([data_name, thread_name, np.array(image), np.array(label)])

            image = []
            label = []
            data_name = []
            thread_name = []

           
第四步:建立2個隊列,分别動态的讀取訓練資料和驗證集資料
def train_set_queue():
    """
    Start the background processes that produce training batches.

    Pipeline:
        1. A DataLoader parses the training list named by cfg.train.train_set.
        2. One process endlessly pushes the record names into train_name_queue.
        3. Several worker processes pull names from train_name_queue, load the
           images from disk and push finished batches into train_queue.

    Every process is started as a daemon: the targets are endless loops, and
    non-daemonic children would keep the interpreter from ever exiting.

    :return: queue holding the loaded training batches.
    """
    train_file = cfg.train.train_set
    human_data_train = DataLoader(train_file, 'train')
    print("num of train data: ", len(human_data_train.name_list))

    # single producer for the record names
    train_name_queue = Queue(cfg.train.train_num)
    name_process = Process(target=human_data_train.name_queue_,
                           args=(train_name_queue, ),
                           daemon=True)
    name_process.start()

    # create queue and read train data
    cache_train_data = 100
    train_thread_num = 4

    train_queue = Queue(cache_train_data)
    for thread in range(train_thread_num):
        p_train = Process(target=human_data_train.load_data,
                          args=(cfg.train.batch_size, train_queue, thread, train_name_queue, 'train'),
                          daemon=True)
        p_train.start()
    return train_queue


def valid_set_queue():
    """
    Start the background processes that produce validation batches.

    Same pipeline as train_set_queue(): one process feeds record names into
    valid_name_queue and worker processes turn them into batches.

    Every process is started as a daemon: the targets are endless loops, and
    non-daemonic children would keep the interpreter from ever exiting.

    :return: queue holding the loaded validation batches.
    """
    valid_file = cfg.train.valid_set
    human_data_valid = DataLoader(valid_file, 'valid')
    print("num of valid data: ", len(human_data_valid.name_list))

    # single producer for the record names
    valid_name_queue = Queue(cfg.train.valid_num)
    valid_name_process = Process(target=human_data_valid.name_queue_,
                                 args=(valid_name_queue,),
                                 daemon=True)
    valid_name_process.start()

    # create queue and read valid data
    cache_valid_data = 32
    valid_thread_num = 2

    valid_queue = Queue(cache_valid_data)
    for thread in range(valid_thread_num):
        p_valid = Process(target=human_data_valid.load_data,
                          args=(cfg.train.batch_size, valid_queue, thread, valid_name_queue, 'valid'),
                          daemon=True)
        p_valid.start()
    return valid_queue
           

完整代碼

下面是 DataLoader.py 的完整代碼:

import random
from PIL import Image, ImageEnhance
from cv_rotation import *
from config import cfg
from multiprocessing import Queue, Process
import os


class DataLoader:
    """Parse an annotation list and stream image batches into queues.

    Each record of the list is "<relative image path> <score>".
    """

    def __init__(self, file, mode):
        """Read every non-blank line of ``file`` into ``self.name_list`` and shuffle it.

        :param file: path of the text file listing the samples, one record per line.
        :param mode: accepted for interface compatibility; not used by the constructor.
        """
        self.input = cfg.train.input_size
        self.root_path = cfg.train.root_path

        # The context manager closes the handle even if reading fails.
        # Blank lines are dropped because they carry no record.
        with open(file, 'r') as data:
            stripped_lines = (line.strip() for line in data)
            self.name_list = [line for line in stripped_lines if line]
        random.shuffle(self.name_list)

    def name_queue_(self, name_queue):
        """Endlessly push the records into ``name_queue``, reshuffling before each pass.

        :param name_queue: queue receiving one record string per ``put``.
        :raises ValueError: if there are no records; looping over an empty list
            would spin forever without ever feeding the queue.
        """
        if not self.name_list:
            raise ValueError('name_list is empty: nothing to enqueue')

        while True:
            random.shuffle(self.name_list)
            for name in self.name_list:
                name_queue.put(name)

    def image_enhance(self, img):
        """Apply one randomly selected PIL enhancement to an image.

        One of colour, brightness, contrast or sharpness is picked with equal
        probability and applied with a factor drawn from its own range.

        :param img: image array accepted by ``Image.fromarray``.
        :return: the enhanced image converted back to a numpy array.
        """
        # randint is inclusive on both ends; the range has to start at 0,
        # otherwise the colour branch (p == 0) can never be selected.
        p = random.randint(0, 3)
        a1 = random.uniform(0.8, 2)
        a2 = random.uniform(0.8, 1.4)
        a3 = random.uniform(0.8, 1.7)
        a4 = random.uniform(0.8, 2.5)
        img = Image.fromarray(img)

        if p == 0:
            img = ImageEnhance.Color(img).enhance(a1)
        elif p == 1:
            img = ImageEnhance.Brightness(img).enhance(a2)
        elif p == 2:
            img = ImageEnhance.Contrast(img).enhance(a3)
        else:
            img = ImageEnhance.Sharpness(img).enhance(a4)

        return np.array(img)

    def flip_img(self, img):
        """Mirror the image along its second axis when a uniform draw is <= 0.5.

        :param img: array indexed as [row, col, channel].
        :return: a reversed view of ``img`` or ``img`` itself.
        """
        if np.random.random() <= 0.5:
            return img[:, ::-1, :]
        return img

    @staticmethod
    def show_image(name, data):
        """Show ``data`` in an OpenCV window titled ``name`` (debugging helper).

        Waits for a key press via ``cv2.waitKey(0)`` and then destroys all
        OpenCV windows.
        """
        cv2.imshow(name, data)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def pose_rotation(self, img):
        """Rotate the image about its centre by a random angle in [-15, 15] degrees.

        :param img: image array whose first two axes are (rows, cols).
        :return: the rotated image returned by ``affine_transform_cv2``.
        """
        # A numpy image is shaped (rows, cols, ...), i.e. (height, width, ...).
        # The centre offset expects x=width and y=height, so unpack in that order.
        h, w = img.shape[:2]
        deg = random.uniform(-15.0, 15.0)
        M_rotate = affine_rotation_matrix(angle=deg)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        return affine_transform_cv2(img, transform_matrix)

    def load_data(self, batch_size, queue, thread, name_queue, mode):
        """
        Endlessly read records from the name queue, load the images and emit batches.

        A record is "<relative path> <score>"; the path may contain spaces and the
        score is the text after the last space. A score >= 10 gives label 1,
        anything else label 0. Records whose image is missing or cannot be
        decoded are reported and skipped.

        :param batch_size: number of samples per emitted batch.
        :param queue: output queue; receives [names, worker ids, images, labels].
        :param thread: id of the worker running this loop, stored once per sample.
        :param name_queue: queue supplying the record strings.
        :param mode: 'train' enables the random flip; 'train-l' keeps a sample only
            while its class counter is <= 16; any other value keeps every sample.
        :return: never returns; the loop runs until the process is stopped.
        """
        image = []
        label = []
        data_name = []
        thread_name = []

        sign_0 = 0
        sign_1 = 0

        while True:
            data = name_queue.get()

            # Split on the LAST space so that paths containing spaces stay intact.
            relative_path, _, score_text = data.rpartition(' ')
            data_image = self.root_path + relative_path
            score = float(score_text)

            if not os.path.exists(data_image):
                print('資料不存在:', data_image)
                continue

            img = cv2.imread(data_image)
            if img is None:
                # imread signals an unreadable file by returning None.
                print('cannot decode image:', data_image)
                continue

            if score >= 10:
                data_label = 1
                sign_1 = sign_1 + 1
            else:
                data_label = 0
                sign_0 = sign_0 + 1

            # Augmentation: only the random flip is active; image_enhance and
            # pose_rotation are currently disabled.
            if mode == 'train':
                img = self.flip_img(img)

            # Dataset statistics (BGR) computed on
            # "../B7Data/1025_color_new/train_1025_color.txt", kept for reference:
            #   mean = [33.65559903, 120.61937841, 116.81338165]
            #   std  = [6.550119402970968, 6.312448303275082, 8.977662213055952]
            # They are not applied; pixels are only scaled to [0, 1].
            img = img.astype(np.float32) / 255.0

            if mode == 'train-l':
                # NOTE(review): the counters are never reset and thread_name is not
                # filled in this mode - confirm the intended behaviour.
                if (sign_0 <= 16) and (score < 10):
                    data_name.append(data_image)
                    image.append(img)
                    label.append(data_label)

                if (sign_1 <= 16) and (score >= 10):
                    data_name.append(data_image)
                    image.append(img)
                    label.append(data_label)
            else:
                data_name.append(data_image)
                image.append(img)
                thread_name.append(thread)
                label.append(data_label)

            if len(image) != batch_size:
                continue

            queue.put([data_name, thread_name, np.array(image), np.array(label)])

            image = []
            label = []
            data_name = []
            thread_name = []


def train_set_queue():
    """
    Start the background processes that produce training batches.

    Pipeline:
        1. A DataLoader parses the training list named by cfg.train.train_set.
        2. One process endlessly pushes the record names into train_name_queue.
        3. Several worker processes pull names from train_name_queue, load the
           images from disk and push finished batches into train_queue.

    Every process is started as a daemon: the targets are endless loops, and
    non-daemonic children would keep the interpreter from ever exiting.

    :return: queue holding the loaded training batches.
    """
    train_file = cfg.train.train_set
    human_data_train = DataLoader(train_file, 'train')
    print("num of train data: ", len(human_data_train.name_list))

    # single producer for the record names
    train_name_queue = Queue(cfg.train.train_num)
    name_process = Process(target=human_data_train.name_queue_,
                           args=(train_name_queue, ),
                           daemon=True)
    name_process.start()

    # create queue and read train data
    cache_train_data = 100
    train_thread_num = 4

    train_queue = Queue(cache_train_data)
    for thread in range(train_thread_num):
        p_train = Process(target=human_data_train.load_data,
                          args=(cfg.train.batch_size, train_queue, thread, train_name_queue, 'train'),
                          daemon=True)
        p_train.start()
    return train_queue


def valid_set_queue():
    """
    Start the background processes that produce validation batches.

    Same pipeline as train_set_queue(): one process feeds record names into
    valid_name_queue and worker processes turn them into batches.

    Every process is started as a daemon: the targets are endless loops, and
    non-daemonic children would keep the interpreter from ever exiting.

    :return: queue holding the loaded validation batches.
    """
    valid_file = cfg.train.valid_set
    human_data_valid = DataLoader(valid_file, 'valid')
    print("num of valid data: ", len(human_data_valid.name_list))

    # single producer for the record names
    valid_name_queue = Queue(cfg.train.valid_num)
    valid_name_process = Process(target=human_data_valid.name_queue_,
                                 args=(valid_name_queue,),
                                 daemon=True)
    valid_name_process.start()

    # create queue and read valid data
    cache_valid_data = 32
    valid_thread_num = 2

    valid_queue = Queue(cache_valid_data)
    for thread in range(valid_thread_num):
        p_valid = Process(target=human_data_valid.load_data,
                          args=(cfg.train.batch_size, valid_queue, thread, valid_name_queue, 'valid'),
                          daemon=True)
        p_valid.start()
    return valid_queue

           

上述代碼中,資料增強的随機旋轉需要從 cv_rotation.py 中擷取相應的實作函數,完整代碼如下:

import numpy as np
import cv2
# import scipy.ndimage as ndi


def affine_rotation_matrix(angle=(-20, 20)):
    """Build a 3x3 homogeneous matrix for rotating an image.
    NOTE: In OpenCV, x is width and y is height.

    Parameters
    -----------
    angle : int/float or tuple of two int/float
        Rotation in degrees.
            - int/float, used as given.
            - tuple, the angle is sampled uniformly between its two values.

    Returns
    -------
    numpy.array
        The 3x3 rotation matrix.

    """
    degrees = np.random.uniform(angle[0], angle[1]) if isinstance(angle, tuple) else angle
    theta = np.pi / 180 * degrees
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[cos_t, sin_t, 0],
                     [-sin_t, cos_t, 0],
                     [0, 0, 1]])


def affine_horizontal_flip_matrix(prob=0.5):
    """Build a 3x3 matrix that mirrors the x axis with a given probability.
    NOTE: In OpenCV, x is width and y is height.

    Parameters
    ----------
    prob : float
        Flip when ``prob`` is >= a uniform draw from [0, 1); 1.0 always flips.

    Returns
    -------
    numpy.array
        The flip matrix, or the identity when no flip was drawn.

    """
    draw = np.random.uniform(0, 1)
    x_scale = -1. if prob >= draw else 1.
    return np.array([[x_scale, 0., 0.], [0., 1., 0.], [0., 0., 1.]])


def affine_shift_matrix(wrg=(-0.1, 0.1), hrg=(-0.1, 0.1), w=200, h=200):
    """Build a 3x3 homogeneous matrix for translating an image.
    NOTE: In OpenCV, x is width and y is height.

    Parameters
    -----------
    wrg : float or tuple of floats
        Shift along the width as a fraction of ``w``.
            - float, used as given.
            - tuple of 2 floats, the fraction is sampled uniformly between them.
    hrg : float or tuple of floats
        Shift along the height as a fraction of ``h``, same conventions as ``wrg``.
    w, h : int
        The width and height of the image.

    Returns
    -------
    numpy.array
        The 3x3 translation matrix.

    """
    # The width fraction is drawn before the height fraction.
    w_fraction = np.random.uniform(wrg[0], wrg[1]) if isinstance(wrg, tuple) else wrg
    h_fraction = np.random.uniform(hrg[0], hrg[1]) if isinstance(hrg, tuple) else hrg
    tx = w_fraction * w
    ty = h_fraction * h
    return np.array([[1, 0, tx],
                     [0, 1, ty],
                     [0, 0, 1]])


def affine_shear_matrix(x_shear=(-0.1, 0.1), y_shear=(-0.1, 0.1)):
    """Build a 3x3 homogeneous matrix for shearing an image.
    NOTE: In OpenCV, x is width and y is height.

    Parameters
    -----------
    x_shear : float or tuple of two floats
        Shear factor placed at row 0, column 1; a tuple is sampled uniformly
        between its two values.
    y_shear : float or tuple of two floats
        Shear factor placed at row 1, column 0; same conventions as ``x_shear``.

    Returns
    -------
    numpy.array
        The 3x3 shear matrix.

    """
    # The x factor is drawn before the y factor.
    kx = np.random.uniform(x_shear[0], x_shear[1]) if isinstance(x_shear, tuple) else x_shear
    ky = np.random.uniform(y_shear[0], y_shear[1]) if isinstance(y_shear, tuple) else y_shear

    return np.array([[1, kx, 0],
                     [ky, 1, 0],
                     [0, 0, 1]])


def affine_zoom_matrix(zoom_range=(0.8, 1.1)):
    """Build a 3x3 homogeneous matrix that scales width and height equally.
    OpenCV format, x is width.

    Parameters
    -----------
    zoom_range : float or tuple of 2 floats
        The scaling ratio, greater than 1 means larger.
            - float/int, used as given.
            - tuple of 2 floats, the ratio is sampled uniformly between them.

    Returns
    -------
    numpy.array
        The 3x3 scaling matrix.

    Raises
    ------
    Exception
        If ``zoom_range`` is neither a number nor a tuple.

    """
    if isinstance(zoom_range, tuple):
        scale = np.random.uniform(zoom_range[0], zoom_range[1])
    elif isinstance(zoom_range, (float, int)):
        scale = zoom_range
    else:
        raise Exception("zoom_range: float or tuple of 2 floats")

    return np.array([[scale, 0, 0],
                     [0, scale, 0],
                     [0, 0, 1]])


def affine_respective_zoom_matrix(w_range=0.8, h_range=1.1):
    """Build a 3x3 homogeneous matrix that scales width and height independently.
    OpenCV format, x is width.

    Parameters
    -----------
    w_range : float or tuple of 2 floats
        The scaling ratio of the width, greater than 1 means larger.
            - float/int, used as given.
            - tuple of 2 floats, the ratio is sampled uniformly between them.
    h_range : float or tuple of 2 floats
        The scaling ratio of the height, same conventions as ``w_range``.

    Returns
    -------
    numpy.array
        The 3x3 scaling matrix.

    Raises
    ------
    Exception
        If either argument is neither a number nor a tuple; ``h_range`` is
        validated first.

    """
    # The height ratio is resolved (and drawn) before the width ratio.
    if isinstance(h_range, tuple):
        zy = np.random.uniform(h_range[0], h_range[1])
    elif isinstance(h_range, (float, int)):
        zy = h_range
    else:
        raise Exception("h_range: float or tuple of 2 floats")

    if isinstance(w_range, tuple):
        zx = np.random.uniform(w_range[0], w_range[1])
    elif isinstance(w_range, (float, int)):
        zx = w_range
    else:
        raise Exception("w_range: float or tuple of 2 floats")

    return np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]])


# affine transform
def transform_matrix_offset_center(matrix, x, y):
    """Re-anchor a transform from the image centre to the top-left origin.

    The given matrix is wrapped as ``offset . matrix . reset``, where ``reset``
    moves the point ((x - 1) / 2, (y - 1) / 2) to the origin and ``offset``
    moves it back.

    Parameters
    ----------
    matrix : numpy.array, 3x3 transform matrix.
    x and y : 2 int, Size of image.

    Returns
    -------
    numpy.array
        The re-anchored 3x3 transform matrix.
    """
    center_x = (x - 1) / 2.0
    center_y = (y - 1) / 2.0
    to_center = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
    to_origin = np.array([[1, 0, -center_x], [0, 1, -center_y], [0, 0, 1]])
    anchored = np.dot(to_center, matrix)
    return np.dot(anchored, to_origin)


# def affine_transform(x, transform_matrix, channel_index=2, fill_mode='nearest', cval=0., order=1):
#     """Return transformed images by given an affine matrix in Scipy format (x is height).
#
#     Parameters
#     ----------
#     x : numpy.array
#         An image with dimension of [row, col, channel] (default).
#     transform_matrix : numpy.array
#         Transform matrix (offset center), can be generated by ``transform_matrix_offset_center``
#     channel_index : int
#         Index of channel, default 2.
#     fill_mode : str
#         Method to fill missing pixel, default `nearest`, more options `constant`, `reflect` or `wrap`, see `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`__
#     cval : float
#         Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0
#     order : int
#         The order of interpolation. The order has to be in the range 0-5:
#             - 0 Nearest-neighbor
#             - 1 Bi-linear (default)
#             - 2 Bi-quadratic
#             - 3 Bi-cubic
#             - 4 Bi-quartic
#             - 5 Bi-quintic
#             - `scipy ndimage affine_transform <https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.ndimage.interpolation.affine_transform.html>`__
#
#     Returns
#     -------
#     numpy.array
#         A processed image.
#
#     Examples
#     --------
#     >>> M_shear = tl.prepro.affine_shear_matrix(intensity=0.2, is_random=False)
#     >>> M_zoom = tl.prepro.affine_zoom_matrix(zoom_range=0.8)
#     >>> M_combined = M_shear.dot(M_zoom)
#     >>> transform_matrix = tl.prepro.transform_matrix_offset_center(M_combined, h, w)
#     >>> result = tl.prepro.affine_transform(image, transform_matrix)
#
#     """
#     # transform_matrix = transform_matrix_offset_center()
#     # asdihasid
#     # asd
#
#     x = np.rollaxis(x, channel_index, 0)
#     final_affine_matrix = transform_matrix[:2, :2]
#     final_offset = transform_matrix[:2, 2]
#     channel_images = [
#         ndi.interpolation.
#         affine_transform(x_channel, final_affine_matrix, final_offset, order=order, mode=fill_mode, cval=cval)
#         for x_channel in x
#     ]
#     x = np.stack(channel_images, axis=0)
#     x = np.rollaxis(x, 0, channel_index + 1)
#     return x
#
#
# apply_transform = affine_transform


def affine_transform_cv2(x, transform_matrix, flags=None, border_mode='constant'):
    """Warp an image with an affine matrix in OpenCV format (x is width).

    Parameters
    ----------
    x : numpy.array
        An image with dimension of [row, col, channel] (default).
    transform_matrix : numpy.array
        A 3x3 transform matrix, OpenCV format; only its first two rows are used.
    flags : int or None
        Interpolation flag passed to ``cv2.warpAffine``; ``cv2.INTER_AREA``
        when None.
    border_mode : str
        - `constant`, pad the image with a constant value (i.e. black or 0)
        - `replicate`, the row or column at the very edge of the original is replicated to the extra border.

    Returns
    -------
    numpy.array
        The warped image, same number of rows and columns as ``x``.

    Raises
    ------
    ValueError
        If ``border_mode`` is not one of the supported names.
    """
    rows, cols = x.shape[0], x.shape[1]

    # Compare by value: `is` tests object identity, which is not guaranteed to
    # hold for two equal strings.
    if border_mode == 'constant':
        border_mode = cv2.BORDER_CONSTANT
    elif border_mode == 'replicate':
        border_mode = cv2.BORDER_REPLICATE
    else:
        raise ValueError("unsupport border_mode, check cv.BORDER_ for more details.")

    if flags is None:
        flags = cv2.INTER_AREA

    # warpAffine takes the 2x3 part of the matrix and the size as (width, height).
    return cv2.warpAffine(x, transform_matrix[0:2, :], (cols, rows), flags=flags, borderMode=border_mode)


def affine_transform_keypoints(coords_list, transform_matrix):
    """Apply an affine transform matrix to groups of keypoint coordinates.
       OpenCV format, x is width.

    Every group is turned into homogeneous column vectors, multiplied by
    ``transform_matrix`` and returned as an array of (x, y) rows.

    Parameters
    -----------
    coords_list : list of list of tuple/list
        Groups of (x, y) coordinates,
        e.g., the keypoint coordinates of every person in an image.
    transform_matrix : numpy.array
        3x3 transform matrix, OpenCV format.

    Returns
    -------
    list of numpy.array
        One array of transformed (x, y) rows per input group.
    """
    transformed_groups = []
    for group in coords_list:
        columns = np.asarray(group).transpose([1, 0])
        # append the homogeneous row of ones below the x and y rows
        homogeneous = np.insert(columns, 2, 1, axis=0)
        mapped = np.matmul(transform_matrix, homogeneous)
        transformed_groups.append(mapped[0:2, :].transpose([1, 0]))
    return transformed_groups


if __name__ == '__main__':
    # Visual check: draw the labelled keypoints on every rendered image, rotate
    # image and keypoints by 19 degrees and show both results.
    import os

    render_path = '../train_0628/render/'
    label_path = '../train_0628/label/'
    point_count = 7

    for file_name in os.listdir(render_path):
        src = cv2.imread(render_path + file_name)
        h, w, c = src.shape

        # the label file shares the image's base name; columns 1 and 2 are read as x and y
        label = np.loadtxt(label_path + file_name[:-4] + '.txt')
        coords = label[:, 1:3]

        for i in range(point_count):
            center = (int(coords[i][0]), int(coords[i][1]))
            cv2.circle(src, center, 5, (255, 0, 0), -1)

        cv2.imshow('src', src)
        cv2.waitKey(0)

        # rotation about the image centre, expressed in image coordinates
        M_rotate = affine_rotation_matrix(angle=19)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        # one warp applies the complete transform to the image
        dst = affine_transform_cv2(src, transform_matrix)

        print(coords)

        # a single group holding the seven (x, y) keypoints
        coords = [[(coords[i, 0], coords[i, 1]) for i in range(point_count)]]

        coords_result = affine_transform_keypoints(coords, transform_matrix)
        print(coords_result)

        dst_color = dst
        rotated_points = coords_result[0]
        for i in range(point_count):
            dst_points_lt = (int(rotated_points[i, 0]), int(rotated_points[i, 1]))
            cv2.circle(dst_color, dst_points_lt, 3, (0, 0, 255), -1)
        cv2.imshow('result', dst_color)
        cv2.waitKey()
        cv2.destroyAllWindows()
           

繼續閱讀