天天看點

深度網絡的過拟合問題讨論



問題背景

最近做深度學習實驗的時候遇到了一個很棘手的問題,那就是大名鼎鼎的​​“過拟合”​​,直覺地表現在圖中是長這個樣子的,分析來講就是說深度網絡在拟合訓練集的時候是可以很好地實作,Loss很小,Accuracy很大(我這兒能達到99.99%),但是呢,測試集的Loss很大,Accuracy在一個比較低的範圍内波動(我這兒是70%-80%),并沒有像論文中說的那樣,測試集的Loss随着疊代的增加而減小,Accuracy随着疊代的增加而增大。



深度網絡的過拟合問題讨論

如果你沒有看出來上圖有什麼毛病的話,我就放一張理想狀态的結果圖做對比(如下圖粗粗的線),畫的比較挫,但是大概的意思在那兒,随着疊代的增加,訓練集和測試集的精确度應該上升,我們可以容忍測試集的精确度沒有訓練集那麼高,畢竟有拟合的誤差,但是像上圖我做出來的結果那樣,一定是“過拟合”啦。 

深度網絡的過拟合問題讨論

用白話來說“過拟合”就是:老師給你的題你都會做了,考試給你換個花樣你就懵逼了。好,老師給你的題就相當于我們的訓練資料,考試的題相當于測試資料,“過拟合”就是深度網絡把訓練的資料拟合的特别好,但是有點好過頭了,對訓練資料當然是100%好用,但是一來測試資料就瘋了,那這樣的網絡訓練出來其實是沒有用的,訓練集已經是監督學習了,拟合的再好也沒用。

展現在函數上就是下圖

正常是測試資料是一個線性或者二次多項式的分布,如果過拟合了,深度網絡很有可以弄出一個特别複雜的拟合曲線函數,把上面所有的黑點點都穿過,當然訓練資料的誤差超級小,但是測試資料一來整個的誤差就比較高了。

網絡結構介紹

我實驗中用到的深度網絡結構原型是​​Fully Convolutional Networks​​,參考的論文中也叫它​​U-Net​​,總之就是一個用來做圖像分割的深度網絡。示意圖如下:  

深度網絡的過拟合問題讨論

 用​​Keras​​的實作代碼是:

深度網絡的過拟合問題讨論

大概的問題背景和網絡結構介紹完畢,更多實驗Details請參考​​PET/CT images segmentation via U-Net, using keras.​​

問題分析

當年LeNet-5在手寫字的識别上出盡了風頭,但是當LeNet-5應用到其他資料集中的時候卻出現了很多問題,從此,學者們開始了瘋狂的理論、實踐探索。“過拟合”問題算是深度學習中一個特别重要的問題,老生常談了,也有不少解決的方法供我選擇。

舉例來講(感謝“知乎深度學習Keras”——QQ群中大神們的幫助):

1. 加入​​Dropout層​​

2. 檢查資料集是否過小(Data Augmentation)

3. 用一用遷移學習的思想

4. 調參小tricks.

  • 調國小習速率(Learning Rate)
  • 調小每次反向傳播的訓練樣本數(batch_size)

5. 試一試别的優化器(optimizer)

6. Keras的回調函數EarlyStopping()

評價:我認為第一個是比較可行,因為“教科書”上的确有說dropout是專門用來對付“過拟合”問題的。

關于資料集的大小,這也是導緻過拟合的原因,如果太小,很容易過拟合。那麼多大的資料集夠了呢?反正我的肯定夠了,我的深度網絡輸入圖像是369,468幅,68*80像素的,二通道輸入,總共的大小是19.5GB。這個資料量可以說是十分可觀了,是以對我來說,第二條可能不适用。那麼如果想要擴充資料集,需要用到Data Augmentation,這個是在醫學影像中十分常用的手段,包括平移,旋轉,拉伸,扭曲等等變換造出新的資料,來增加資料量。

第三條是深度學習中比較有效的方法了,英文名叫fine-tuning,就是用已有的訓練完的網絡參數作為初始化,在這個基礎上繼續訓練。原來的網絡參數往往會在很多論文和github裡頭能找到,這是在很大的圖像資料集中訓練完的網絡,根據圖形圖像的“語義”相似性(我也不知道該怎麼描述,就是認為世界上的圖檔都有某種相似性,就像人類,每個人都長得不一樣,但是你不會把人和其他動物混在一起,這就是一個宏觀的,抽象的相似性),把這個網絡“遷移”到一個新的圖像資料集中是有一定的道理的。由于時間原因,我暫時還沒有采用這個。

第四條就是比較說不清道不明的調參了,這幾乎是機器學習的主要話題,人說“有多少人工,就有多少智能”,這個調參真的需要“經驗”啊哈哈哈。

不好意思,第五條又是試湊法。。。可供選擇的​​Optimizers有很多​​,都試一下,看看用哪兒效果好,聽上去有點喪心病狂了。

第六條方法是一個小函數,叫做EarlyStopping,代碼如下 

[python]

  1. early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5)  

作用是監視每次疊代的名額,比如說這兒監視的是val_loss(測試集的Loss),随着疊代的增加,當val_loss不再發生大的變化的時候可以終止訓練,在過拟合之前阻斷。這種政策也稱為“No-improvement-in-n”,n就是Epoch的次數。

不幸的是,以上六個方案,我測試了以後都沒有很好地解決“過拟合”問題。

 正則化方法

正則化方法是指在進行目标函數或代價函數優化時,在目标函數或代價函數後面加上一個正則項,一般有L1正則與L2正則等。這個很理論了,會涉及到一些公式。

深度網絡的過拟合問題讨論

​​這部分内容​​我在學習的時候就當它純理論來記,當時根本沒有想過會去真正用它,看來現在是必須要try一下了。

深度網絡的過拟合問題讨論

要在Keras中修改這部分代價函數(​​Objectives​​)的代碼,可以參考​​這部分内容​​,裡面包括了若幹個代價函數,如果想要自己編寫代價函數也可以的。根據這個部落格:​​基于Theano的深度學習(Deep Learning)架構Keras學習随筆-08-規則化(規格化)​,有效解決過拟合的方法就是加入規則項。具體的規則化可以參見​​深度學習(DL)與卷積神經網絡(CNN)學習筆記随筆-04-基于Python的LeNet之MLP​中對于規則化的介紹。部落客​​Tig_Free​是真神啊,膜拜一下!

問題解決

最終,過拟合的現象基本上被控制住了,總的來說,L1/L2規範化的确是很牛逼,在學術論文中也有所展現:

網絡調整如下:

  1. model = Sequential()  
  2.     model.add(Convolution2D(64, 4, 4, border_mode='valid', input_shape=data.shape[-3:]))  
  3.     model.add(Convolution2D(64, 3, 3, border_mode='valid'))  
  4.     model.add(Activation('relu'))  
  5.     model.add(Dropout(0.3))   
  6.     #model.add(BatchNormalization(epsilon=1e-06, mode=0, axis=-1, momentum=0.9, weights=None, beta_init='zero', gamma_init='one'))  
  7.     model.add(MaxPooling2D(pool_size=(2, 2)))  
  8.     model.add(Convolution2D(64, 4, 4, border_mode='valid'))  
  9.     model.add(Convolution2D(64, 3, 3, border_mode='valid'))   
  10.     model.add(Convolution2D(128, 4, 4, border_mode='valid'))  
  11.     model.add(Convolution2D(128, 3, 3, border_mode='valid'))   
  12.     model.add(Activation('relu'))     
  13.     #model.add(BatchNormalization(epsilon=1e-06, mode=0, axis=-1, momentum=0.9, weights=None, beta_init='zero', gamma_init='one'))    
  14.     model.add(Convolution2D(128, 3, 3, border_mode='valid'))  
  15.     model.add(Dropout(0.3))  
  16.     model.add(MaxPooling2D(pool_size=(2, 2)))         
  17.     model.add(Flatten())  
  18.     model.add(Dense(512, init='normal',W_regularizer=l2(0.02), activity_regularizer=activity_l2(0.01)))  
  19.     model.add(Dense(LABELTYPE, init='normal'))  
  20.     model.add(Activation('softmax'))   

配置:)

  • dropout層(0.3)
  • 全連接配接層的L2規範化
  • 優化器(adadelta)
  • 學習速率(1e-9)

繼續閱讀