天天看點

Tensorflow2.0中資料增強及訓練過程中資料增強

資料增強:它是正則化的一種形式,使我們的網絡可以更好地将其推廣到我們的測試/驗證集。

ImageDataGenerator工作原理:

ImageDataGenerator接受原始資料,對其進行随機轉換,并僅傳回轉換後的新資料。接受一批用于訓練的圖像;

進行此批處理并對批進行中的每個圖像應用一系列随機變換(包括随機旋轉,調整大小,剪切等);

用新的,随機轉換的批次替換原始批次;

在此随機轉換的批次上訓練CNN(即原始資料本身不用于訓練)。

1)對圖檔進行資料增強,并将其結果儲存到檔案夾中:

""
一: 定義ImageDataGenerator 圖檔生成器
""
from tensorflow.keras.preprocessing.image import ImageDataGenerator


""
二: 封裝flow_from_directory()
其中:
path:檔案讀入的路徑,必須是子檔案夾的上一級(這裡是個坑,不過試一哈就懂了)
target_size:圖檔resize成的尺寸,不設定會預設設定為(256.256)
batch_size:每次輸入的圖檔的數量,例如batch_size=32,一次進行增強的數量為32,
個人經驗:batch_size的大小最好是應該和檔案的數量是可以整除的關系
save_to_dir:增強後圖檔的儲存位置
save_prefix:檔案名加字首,友善檢視
save_format:儲存圖檔的資料格式
産生的圖檔總數:batch_size*6(即range中的數字)
""
gen = datagen.flow_from_directory(
                           path,
                           target_size=(224, 224),
                           batch_size=15,
                           save_to_dir=dst_path,
                           save_prefix='xx',
                           save_format='jpg')

""
三: 調用gen.next()執行增強過程
""
for i in range(6):
    gen.next()
           

2)訓練過程中資料增強

from tensorflow.keras.preprocessing.image import ImageDataGenerator

""
一: 定義ImageDataGenerator
""
datagen = ImageDataGenerator(
        # 布爾值,使輸入資料集去中心化(均值為0), 按feature執行
        featurewise_center=False,

        # 布爾值,使輸入資料的每個樣本均值為0
        samplewise_center=False,

        # 布爾值,将輸入除以資料集的标準差以完成标準化, 按feature執行
        featurewise_std_normalization=False,

        # 布爾值,将輸入的每個樣本除以其自身的标準差
        samplewise_std_normalization=False,

        # 布爾值,對輸入資料施加ZCA白化
        zca_whitening=False,

        # ZCA使用的eposilon,預設1e-6
        zca_epsilon=1e-06,

        # 整數,資料提升時圖檔随機轉動的角度 (deg 0 to 180)
        rotation_range=0,

        # 浮點數,圖檔寬度的某個比例,資料提升時圖檔水準偏移的幅度
        width_shift_range=0.1,

        # 浮點數,圖檔高度的某個比例,資料提升時圖檔豎直偏移的幅度
        height_shift_range=0.1,

        # 浮點數,剪切強度(逆時針方向的剪切變換角度)
        shear_range=0.,

        # 浮點數或形如[lower,upper]的清單,随機縮放的幅度,若為浮點數,則相當于[lower,upper] 
        #  = [1 - zoom_range, 1+zoom_range]
        zoom_range=0.,

        # 浮點數,随機通道偏移的幅度
        channel_shift_range=0.,

        # ‘constant’,‘nearest’,‘reflect’或‘wrap’之一,當進行變換時超出邊界的點将根據本參數 
        # 給定的方法進行處理
        fill_mode='nearest',

        # 浮點數或整數,當fill_mode=constant時,指定要向超出邊界的點填充的值
        cval=0.,

        # 布爾值,是否進行随機水準翻轉
        horizontal_flip=True,

        # 布爾值,是否進行随機豎直翻轉
        vertical_flip=False,

        # 重放縮因子,預設為None. 如果為None或0則不進行放縮,否則會将該數值乘到資料上(在應用其 
        # 他變換之前)
        rescale=None,

        # 将被應用于每個輸入的函數。該函數将在圖檔縮放和資料提升之後運作。該函數接受一個參數, 
        # 為一張圖檔(秩為3的numpy array),并且輸出一個具有相同shape的numpy array
        preprocessing_function=None,

        # 字元串,“channel_first”或“channel_last”之一,代表圖像的通道維的位置。
        data_format=None,

        # 驗證集切分比重 (strictly between 0 and 1)
        validation_split=0.0)

""
二:fit中調用
1):fit ---->fit_generator
2):傳入資料集變為datagen.flow(x_train, y_train, batch_size=batch_size)
""
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=len(x_train) // batch_size,
                    epochs=epochs, verbose=1, workers=4,
                    callbacks=callbacks,
                    use_multiprocessing=False)
           

fit 中的 verbose:

verbose:日志顯示

verbose = 0 為不在标準輸出流輸出日志資訊

verbose = 1 為輸出進度條記錄

verbose = 2 為每個epoch輸出一行記錄

注意: 預設為 1

繼續閱讀