天天看點

Keras: fit_generator中如何建構一個generator?

為何使用fit_generator?

在深度學習中,我們資料通常會很大,即使在使用GPU的情況下,我們如果一次性将所有資料(如圖像)讀入CPU的記憶體中,記憶體很有可能會奔潰。這在實際的項目中很有可能會出現。

如果我們使用fit_generator則可以解決這個問題:

1)fit_generator的參數中有一個是連續不斷的産生資料的函數,被稱為generator。

2)至于這個generator是怎麼産生的,本文不想多說。本文隻想告訴大家怎麼來建構一個實際的generator。

實際的generator例子

import json
import os
import numpy as np
import cv2
from sklearn.utils import shuffle

def cv_imread(filePath):
    cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)

    return cv_img

def load_train(train_path, width, height, batch_size):
    classes = np.zeros(61)
    root = os.getcwd()
    with open(train_path, 'r') as load_f:
        load_dict = json.load(load_f)
        start = 0
        end = batch_size
        num_epochs = 0
        while True:
            images = []
            labels = []
            number = np.random.random_integers(0, len(load_dict)-1, batch_size)
            for image in number:
                index = load_dict[image]["disease_class"]
                path = load_dict[image]['image_id']
                img_path = os.path.join(root, 'new_train', 'images', path)
                image_data = cv_imread(img_path)
                image_data = cv2.resize(image_data, (width, height), 0, 0, cv2.INTER_LINEAR)
                image_data = image_data.astype(np.float32)
                image_data = np.multiply(image_data, 1.0 / 255.0)
                images.append(image_data)
                label = np.zeros(len(classes))
                label[index] = 1
                labels.append(label)
            images = np.array(images)
            labels = np.array(labels)
            yield images, labels

def load_validate(validate_path, width, height):
    root = os.getcwd()
    with open(validate_path, 'r') as load_f:
        load_dict = json.load(load_f)

        # num_image = len(load_dict)
        # 隻産生512個資料,避免記憶體過大
        while True:
            images = []
            labels = []
            classes = np.zeros(61)
            number = np.random.random_integers(0, len(load_dict) - 1, 512)

            for image in number:
                index = load_dict[image]["disease_class"]
                path = load_dict[image]['image_id']
                img_path = os.path.join(root, 'AgriculturalDisease_validationset', 'images', path)
                image_data = cv_imread(img_path)
                image_data = cv2.resize(image_data, (width, height), 0, 0, cv2.INTER_LINEAR)
                image_data = image_data.astype(np.float32)
                image_data = np.multiply(image_data, 1.0 / 255.0)

                images.append(image_data)
                label = np.zeros(len(classes))
                label[index] = 1
                labels.append(label)
            images = np.array(images)
            labels = np.array(labels)

            yield images, labels
           

以上是切實可行的程式。這裡對上面的程式做一個說明:

1)注意到函數中使用yield傳回資料

2)注意到函數使用while True 來進行循環,目前可以認為這是一個必要的部分,這個函數不停的在while中進行循環

3)由于是在while中進行循環,我們需要在while中進行設定初始化,而不要在while循環外進行初始化;我剛開始在load_validate函數中沒有初始化 images = []和labels = [],導緻程式出錯。因為我在while循環中最後将這兩個資料都變成了numpy的資料格式,當進行第二輪資料産生時,numpy的資料格式是沒有append的函數的,是以會出錯。

4)程式中具體的資料運算不需要太多了解,不過這裡給出一個簡單的說明,以助于了解:在train資料中,我試圖從一個很大的圖檔資料庫中随機選擇batch_size個圖檔,然後進行resize變換。這是一張圖檔的過程。為了讀取多張圖檔,我是先将每一個圖檔都讀入一個清單中,這是為了使用清單的append這個追加資料的功能(我覺得這個功能其實挺好用的),最後,把要訓練的一個batch資料轉成numpy的array格式。

5)除了while True 和 yield,大家留意一下這裡的循環和初始化,比較容易出錯。

最後,這裡也給出這個程式的一些參數設定

個人覺得這裡的參數設定還是不太友善的,需要注意一下。

1)首先,對于train中的資料,batch是從主函數中讀進來的。

下面是fit_generator調用設定

times = 3070
batch_size = 64
model.fit_generator(load_data.load_train(train_path, img_rows, img_cols, batch_size=batch_size),
                    steps_per_epoch=times,
                    verbose=1,
                    epochs=100,
                    validation_data=load_data.load_validate(validate_path, img_rows, img_cols),
                    validation_steps=1,
                    shuffle=True)
           

先說一下,我的資料數量:訓練集共196434張圖檔,驗證集共4095張圖檔。

我這裡主要說一下訓練集的設定,因為驗證集我還沒有仔細的思考。

由于我的GPU數量隻是1,是以我隻是講batch_size設定為64,是以需要從196434/63=3070次才能将整個資料庫平均取一次。而epochs設定為100,意味着我将這個輪回設定成了100,。

以上有不對的地方,或者可以改進的地方,請各位不吝指教。