天天看點

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

本文來自AI新媒體量子位(QbitAI)

Alexandre Attia是《辛普森一家》的狂熱粉絲。他看了一系列辛普森劇集,想建立一個能識别其中人物的神經網絡。

接下來讓我們跟着他的文章來了解下該如何建立一個用于識别《辛普森一家》中各個角色的神經網絡。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

要實作這個項目不是很困難,可能會比較耗時,因為需要手動标注每個人物的多張照片。

目前在網上沒有《辛普森一家》人物的訓練資料集,是以我正在标注各類圖檔來建構訓練資料集。這個資料集的第一個版本已經挂在Kaggle上了,将持續進行更新,希望這個資料集能幫到大家。

在學了用TensorFlow建構不同項目後,我決定用Keras,因為它比TensorFlow更為簡單易上手,而且以TensorFlow作為後端,具有很強的相容性。Keras是Francois Chollet用Python語言編寫的一個深度學習庫。

本文基于卷積神經網絡(CNN)來完成此項目,CNN網絡是一種能夠學習許多特征的多層前饋神經網絡。

該資料集目前有18類,有以下人物:Homer,Marge,Lisa,Bart,Burns,Grampa,Flanders,Moe,Krusty,Sideshow Bob,Skinner,Milhouse等。

我的目标是達到20類,當然類别越多越好。各類樣本的大小不一,圖檔背景也不盡相同,主要是從第4至24季的劇集中提取出來的。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 部分人物的圖檔

在訓練集中,每個人物各大約包括1000個樣本(還在标注資料來達到這個數量)。每個人物不一定處于圖像中間,有時周圍還帶有其他人物。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 人物的樣本量分布

通過label_data.py函數,我們可以從AVI電影中标注資料:得到裁剪後的圖檔(左部分或右部分),或者完整版,然後僅需輸入人物名稱的一部分,如對Charles Montgomery Burns輸入burns。

添加資料時,我也使用了Keras模型。對視訊進行截圖,每一幀可轉化得到3張圖檔,分别是左部分、右部分和完整版,然後通過編寫算法來分類每張圖檔。

之後,我檢查了此算法的分類效果,雖然是手動的,但這是一個漸進的過程,速度将會不斷提升,特别是對出現頻率較低的小類别人物。

在預處理圖檔時,第一步是調整樣本大小。為了節省資料記憶體,先将樣本轉換為float32類型,并除以255進行歸一化。

然後,使用Keras的自帶函數,将各類人物的标簽從名字轉換為數字,再利用one-hot編碼轉換成矢量:

進而,使用sklearn庫的train_test_split函數,将資料集分成訓練集和測試集。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

現在讓我們開始進入最有趣的部分:定義網絡模型。

首先,我們建構了一個前饋網絡,包括4個帶有ReLU激活函數的卷積層和一個全連接配接的隐藏層(随着資料量的增大,可能會進一步加深網絡)。

這個模型與Keras文檔中的CIFAR示例模型比較相近,接下來還會使用更多資料對其他模型進行測試。我還在模型中加入了Dropout層來防止網絡過拟合。在輸出層中,使用softmax函數來輸出各類的所屬機率。

損失函數為分類交叉熵(Categorical Cross Entropy)。優化器optimizer使用了随機梯度下降中的RMS Prop方法,通過該權重臨近視窗的梯度平均值來确定該點的學習率。

這個模型在訓練集上疊代訓練了200次,其中批次大小為32。

由于目前的資料集樣本不多,我還用了資料增強操作,使用Keras庫可以很快地實作。

這實際上是對圖檔進行一些随機變化,如小角度旋轉和加噪聲等,是以輸入模型的樣本都不大相同。這有助于防止模型過拟合,提高模型的泛化能力。

在CPU上訓練模型時會耗費較長時間,是以我使用AWS EC2上的GPU資源:每次疊代需要8秒鐘,一共使用了20分鐘。在訓練深度學習模型時,這已經是較快了。

在200次疊代後,我們畫出了模型名額,可以看出性能已經較為穩定,沒有明顯的過拟合現象,且實際正确率較高。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 訓練時驗證集和訓練集的損失值和正确率

由于目前樣本量較小,是以很難得到準确的模型精度。但随着訓練集樣本的增多,這将更貼近實際的模型性能。我們使用sklearn庫很快地輸出了各類的識别效果。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 各類别的識别效果

從上圖可以看出,模型的正确率(f1-score)較高:除了Lisa,其餘各類的正确率都超過了80%。Lisa類的平均正确率為82%,可能是在樣本中Lisa與其他人物混在一起。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 各類别的交叉關系圖

的确,Lisa樣本中經常帶有Bart,是以正确率較低可能受到Bart的影響。

為了提高模型正确率和減少召回率,我添加了一個門檻值。

在讨論門檻值之前,先介紹下關于召回和正确率的關系圖。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 召回和正确率的關系圖

現在統計下正确預測和錯誤預測的相關資料:最佳機率預測,兩個最相似人物的機率差和标準偏差STD。

正确預測:最大值為0.83,最優點機率差為0.773,STD值為0.21;

錯誤預測:最大值為0.27,最優點機率差為0.092,STD值為0.07。

如果人物1的預測正确率太低,預測人物2時标準偏差太高或是兩個最相似人物間的機率差太低,那麼可以認為網絡沒有學習到這個人物。

是以,對兩個類别,繪制測試集的3個名額,希望找到一個超平面來分離正确預測和錯誤預測。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 測試集中多個名額的散點圖

上圖中,想要通過直線或是設定門檻值,來分離出正确預測和錯誤預測,這是不容易實作的。當然還可以看出,錯誤預測的樣本一般在圖表的左下方,但在這個位置也分布了很多正确預測樣本。如果設定了一個門檻值(關于最相似人物間的機率差和機率),則實際召回率也會降低。

我們希望在提高準确性的同時,而不會很大程度上影響召回率,是以要為每個人物或是低正确率的人物(如Lisa Simpson)來繪制這些散點圖。

此外,對于沒有主角或是不存在人物的樣本,加入門檻值後效果很好。目前我在模型中添加了一個“無人物”的類别,可以添加門檻值來處理。我認為很難在最佳機率預測、機率差和标準偏差之間找到平衡點,是以我重點關注最佳預測機率。

在模型中,很難平衡好召回率與正确率之間的關系,同時也無法同時提高召回率和正确率。是以往往根據實際目标,來提高單個值。

對于預測類别的機率最小值,畫出F1-score、召回率和正确率來比較效果。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 對于所有類别或特定類别,正确率、召回率和F1-score與預測類别機率最小值的關系

從圖10中看出,模型效果取決于不同人物。重點研究Lisa Simpson類别,為該類添加機率最小值0.2可能會提高效果,但是組合所有類别後,這個門檻值并不完全适用。

是以考慮全局效果,對于預測類别的機率最小值,應該增加一個合适的門檻值,且不能位于區間[0.2,0.4]内。

刷劇不忘學CNN:TF+Keras識别辛普森一家人物 | 教程+代碼+資料集準備資料集資料預處理構模組化型訓練模型評估模型添加門檻值來提高正确率關于最佳預測機率的召回率和正确率可視化預測人物相關連結

△ 12個不同人物的實際類别和預測類别

在圖11中,用于分類人物的神經網絡效果很好,故應用到視訊中實時預測。在實際中,每張圖檔的預測時間不超過0.1s,可以做到每秒預測多幀。

1. 辛普森一家的人物資料集:

https://www.kaggle.com/alexattia/the-simpsons-characters-dataset

2. 完整項目代碼:

https://github.com/alexattia/SimpsonRecognition

【完】

本文作者:王小新

原文釋出時間:2017-06-25

繼續閱讀