為何使用fit_generator?
在深度學習中,我們資料通常會很大,即使在使用GPU的情況下,我們如果一次性将所有資料(如圖像)讀入CPU的記憶體中,記憶體很有可能會奔潰。這在實際的項目中很有可能會出現。
如果我們使用fit_generator則可以解決這個問題:
1)fit_generator的參數中有一個是連續不斷的産生資料的函數,被稱為generator。
2)至于這個generator是怎麼産生的,本文不想多說。本文隻想告訴大家怎麼來建構一個實際的generator。
實際的generator例子
import json
import os
import numpy as np
import cv2
from sklearn.utils import shuffle
def cv_imread(filePath):
cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)
return cv_img
def load_train(train_path, width, height, batch_size):
classes = np.zeros(61)
root = os.getcwd()
with open(train_path, 'r') as load_f:
load_dict = json.load(load_f)
start = 0
end = batch_size
num_epochs = 0
while True:
images = []
labels = []
number = np.random.random_integers(0, len(load_dict)-1, batch_size)
for image in number:
index = load_dict[image]["disease_class"]
path = load_dict[image]['image_id']
img_path = os.path.join(root, 'new_train', 'images', path)
image_data = cv_imread(img_path)
image_data = cv2.resize(image_data, (width, height), 0, 0, cv2.INTER_LINEAR)
image_data = image_data.astype(np.float32)
image_data = np.multiply(image_data, 1.0 / 255.0)
images.append(image_data)
label = np.zeros(len(classes))
label[index] = 1
labels.append(label)
images = np.array(images)
labels = np.array(labels)
yield images, labels
def load_validate(validate_path, width, height):
root = os.getcwd()
with open(validate_path, 'r') as load_f:
load_dict = json.load(load_f)
# num_image = len(load_dict)
# 隻産生512個資料,避免記憶體過大
while True:
images = []
labels = []
classes = np.zeros(61)
number = np.random.random_integers(0, len(load_dict) - 1, 512)
for image in number:
index = load_dict[image]["disease_class"]
path = load_dict[image]['image_id']
img_path = os.path.join(root, 'AgriculturalDisease_validationset', 'images', path)
image_data = cv_imread(img_path)
image_data = cv2.resize(image_data, (width, height), 0, 0, cv2.INTER_LINEAR)
image_data = image_data.astype(np.float32)
image_data = np.multiply(image_data, 1.0 / 255.0)
images.append(image_data)
label = np.zeros(len(classes))
label[index] = 1
labels.append(label)
images = np.array(images)
labels = np.array(labels)
yield images, labels
以上是切實可行的程式。這裡對上面的程式做一個說明:
1)注意到函數中使用yield傳回資料
2)注意到函數使用while True 來進行循環,目前可以認為這是一個必要的部分,這個函數不停的在while中進行循環
3)由于是在while中進行循環,我們需要在while中進行設定初始化,而不要在while循環外進行初始化;我剛開始在load_validate函數中沒有初始化 images = []和labels = [],導緻程式出錯。因為我在while循環中最後将這兩個資料都變成了numpy的資料格式,當進行第二輪資料産生時,numpy的資料格式是沒有append的函數的,是以會出錯。
4)程式中具體的資料運算不需要太多了解,不過這裡給出一個簡單的說明,以助于了解:在train資料中,我試圖從一個很大的圖檔資料庫中随機選擇batch_size個圖檔,然後進行resize變換。這是一張圖檔的過程。為了讀取多張圖檔,我是先将每一個圖檔都讀入一個清單中,這是為了使用清單的append這個追加資料的功能(我覺得這個功能其實挺好用的),最後,把要訓練的一個batch資料轉成numpy的array格式。
5)除了while True 和 yield,大家留意一下這裡的循環和初始化,比較容易出錯。
最後,這裡也給出這個程式的一些參數設定
個人覺得這裡的參數設定還是不太友善的,需要注意一下。
1)首先,對于train中的資料,batch是從主函數中讀進來的。
下面是fit_generator調用設定
times = 3070
batch_size = 64
model.fit_generator(load_data.load_train(train_path, img_rows, img_cols, batch_size=batch_size),
steps_per_epoch=times,
verbose=1,
epochs=100,
validation_data=load_data.load_validate(validate_path, img_rows, img_cols),
validation_steps=1,
shuffle=True)
先說一下,我的資料數量:訓練集共196434張圖檔,驗證集共4095張圖檔。
我這裡主要說一下訓練集的設定,因為驗證集我還沒有仔細的思考。
由于我的GPU數量隻是1,是以我隻是講batch_size設定為64,是以需要從196434/63=3070次才能将整個資料庫平均取一次。而epochs設定為100,意味着我将這個輪回設定成了100,。
以上有不對的地方,或者可以改進的地方,請各位不吝指教。