天天看點

微調模型來完成熱狗識别的圖像分類任務

作者:黑馬程式員

我們來實踐一個具體的例子:熱狗識别。将基于一個小資料集對在ImageNet資料集上訓練好的ResNet模型進行微調。該小資料集含有數千張熱狗或者其他事物的圖像。我們将使用微調得到的模型來識别一張圖像中是否包含熱狗。

首先,導入實驗所需的工具包。

import tensorflow as tf
import numpy as np           

擷取資料集

我們首先将資料集放在路徑hotdog/data之下:

微調模型來完成熱狗識别的圖像分類任務

每個類别檔案夾裡面是圖像檔案。

上一節中我們介紹了ImageDataGenerator進行圖像增強,我們可以通過以下方法讀取圖像檔案,該方法以檔案夾路徑為參數,生成經過圖像增強後的結果,并産生batch資料:

flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None)           

主要參數:

  ▪ directory: 目标檔案夾路徑,對于每一個類對應一個子檔案夾,該子檔案夾中任何JPG、PNG、BNP、PPM的圖檔都可以讀取。

  ▪ target_size: 預設為(256, 256),圖像将被resize成該尺寸。

  ▪ batch_size: batch資料的大小,預設32。

  ▪ shuffle: 是否打亂資料,預設為True。

我們建立兩個tf.keras.preprocessing.image.ImageDataGenerator執行個體來分别讀取訓練資料集和測試資料集中的所有圖像檔案。将訓練集圖檔全部處理為高和寬均為224像素的輸入。此外,我們對RGB(紅、綠、藍)三個顔色通道的數值做标準化。

# 擷取資料集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 擷取訓練集資料
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 擷取測試集資料
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 建立imageDataGenerator進行圖像處理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 設定參數
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 擷取訓練資料
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)
# 擷取測試資料
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)           

下面我們随機取1個batch的圖檔然後繪制出來。

import matplotlib.pyplot as plt
# 顯示圖像
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(15):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.axis('off')
# 随機選擇一個batch的圖像        
image_batch, label_batch = next(train_data_gen)
# 圖像顯示
show_batch(image_batch, label_batch)           
微調模型來完成熱狗識别的圖像分類任務

模型建構與訓練

我們使用在ImageNet資料集上預訓練的ResNet-50作為源模型。這裡指定weights='imagenet'來自動下載下傳并加載預訓練的模型參數。在第一次使用時需要聯網下載下傳模型參數。

Keras應用程式(keras.applications)是具有預先訓練權值的固定架構,該類封裝了很多重量級的網絡架構,如下圖所示:

微調模型來完成熱狗識别的圖像分類任務

實作時執行個體化模型架構:

tf.keras.applications.ResNet50(
    include_top=True, weights='imagenet', input_tensor=None, input_shape=None,
    pooling=None, classes=1000, **kwargs
)           

主要參數:

▪ include_top: 是否包括頂層的全連接配接層。

▪ weights: None 代表随機初始化, 'imagenet' 代表加載在 ImageNet 上預訓練的權值。

▪ input_shape: 可選,輸入尺寸元組,僅當 include_top=False 時有效,否則輸入形狀必須是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必須為 3 個輸入通道,且寬高必須不小于 32,比如 (200, 200, 3) 是一個合法的輸入尺寸。

在該案例中我們使用resNet50預訓練模型構模組化型:

# 加載預訓練模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 設定所有層不可訓練
for layer in ResNet50.layers:
    layer.trainable = False
# 設定模型
net = tf.keras.models.Sequential()
# 預訓練模型
net.add(ResNet50)
# 展開
net.add(tf.keras.layers.Flatten())
# 二分類的全連接配接層
net.add(tf.keras.layers.Dense(2, activation='softmax'))           

接下來我們使用之前定義好的ImageGenerator将訓練集圖檔送入ResNet50進行訓練。

# 模型編譯:指定優化器,損失函數和評價名額net.compile(optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])# 模型訓練:指定資料,每一個epoch中隻運作10個疊代,指定驗證資料集history = net.fit(
                    train_data_gen,
                    steps_per_epoch=10,
                    epochs=3,
                    validation_data=test_data_gen,
                    validation_steps=10
                    )           
Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875           

繼續閱讀