天天看點

Keras:使用預訓練模型遷移學習單通道灰階圖像1. 問題引出    2. 解決方案3. 多程序加速運作 4.使用預訓練模型訓練

目錄

1. 問題引出    

2. 解決方案

2.1. 直接使用convert将L轉為RGB

2.2. 數組拼接方法

3. 多程序加速運作 

4.使用預訓練模型訓練

1. 問題引出    

      最近在做一個圖像分類的項目,由于性能比較差,是以需要嘗試将彩色圖轉為灰階圖進行訓練,進而屏蔽掉顔色對分類結果的影響而着重關注紋理、結構等資訊。由于樣本數量較少,隻有幾百張的樣子,如果自己搭網絡的話從頭訓練的話,勢必會因為樣本數量的問題,無法達到一個滿意的效果,是以考慮借鑒Imagenet的預訓練權重。但是在Imagenet上預訓練的模型(Xception, Resnet, VGG等)都是處理的彩色圖,如果要使用預訓練模型就必須要3通道的圖像。

      搜尋了一下,基本上目前的解決方法:

      暴力的将單通道的圖複制為3份,然後合成為一張RGB圖。顯然,該圖3個通道的數值完全相等,這樣存在很多備援計算,我們稱之為“僞RGB圖”。為了友善起見,自己實作了兩種方法,完成如下轉換:

RGB圖  →  灰階圖   →   僞RGB圖

其中,轉換為灰階圖時,均使用的是如下标準公式:

Keras:使用預訓練模型遷移學習單通道灰階圖像1. 問題引出    2. 解決方案3. 多程式加速運作 4.使用預訓練模型訓練

2. 解決方案

首先,導入必要的包:
from multiprocessing import Pool
from PIL import Image
import numpy as np
import os
           

2.1. 直接使用convert将L轉為RGB

def fakeRgb1(path, dst):
    '''
    方法1:直接使用convert将L轉為RGB
    :param path:圖檔輸出路徑
    :param dst:圖檔輸出路徑
    :return:rgb3個通道值相等的rgb圖像
    '''
    b = Image.open(path)
    # L代表轉換為灰階圖
    if b.mode != 'L':
        L = b.convert('L')
    L = L.convert('RGB')
    # 将圖像轉為數組
    rgb_array = np.asarray(L)
    # 将數組轉換為圖像
    rgb_image = Image.fromarray(rgb_array)
    rgb_image.save(dst + '\\' + path.split('\\')[-1])
    print(dst + '\\' + path.split('\\')[-1])
           

2.2. 數組拼接方法

def fakeRgb2(path, dst):
    '''
    方法二:最原始的拼接數組方法
    :param path:圖檔輸入路徑
    :param dst:圖檔輸出路徑
    :return:rgb3個通道值相等的rgb圖像
    '''

    b = Image.open(path)
    # L代表轉換為灰階圖
    if b.mode != 'L':
        L = b.convert('L')
    # 将圖像轉為數組
    b_array = np.asarray(L)
    # 将3個二維數組重疊為一個三維數組
    rgb_array = np.zeros((b_array.shape[0], b_array.shape[1], 3), "uint8")
    rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2] = b_array, b_array, b_array
    rgb_image = Image.fromarray(rgb_array)
    rgb_image.save(dst + '\\' + path.split('\\')[-1])
    print(dst + '\\' + path.split('\\')[-1])
           

3. 多程序加速運作 

由于是批量處理,是以可能會遇到同時轉換很多張圖檔,那麼這個時候就必須使用多程序加速了,具體的加速方法看我的這篇部落格:

Python:多程序運作含有任意個參數的函數

本文的加速代碼如下:

def get_image_paths(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder)]

if __name__ == '__main__': # 多線程,多參數,starmap版本
    images = get_image_paths(path)
    output = [src for i in images]

    zip_args = list(zip(images, output))
    pool = Pool()
    pool.starmap(fakeRgb2, zip_args)
    pool.close()
    pool.join()
           

4.使用預訓練模型訓練

     這部分就和訓練普通RGB圖像一樣即可,在這裡不贅述。 

繼續閱讀