Overfitting and regularization(過拟合和正則化)
諾貝爾獎得主,實體學家Enrico Fermi曾和他的同僚讨論過一個作為一個重要的實體問題的解決方案的數學模型。 這個模型取得了很好的實驗效果,但是Fermi 有質疑。他問這個模型可以設定多少個自由變量。回答是4個。Fermi回複,我記得我的一個朋友Johnny von Neumann曾經說過,使用4個參數我可以模拟一個大象,用5個參數我可以讓大象擺動鼻子。
就是說如果那個模型能支援大量的參數,就能過描述相當多的現象。即使這樣一個在可靠的資料下表現良好的模型,也不能證明是好的模型。它可能意味着這個模型有足夠大的freedom,使它能夠描述給定的幾乎所有的資料集,卻不能捕捉到任何表象下的本質。這時,這個模型在給定的資料下可以很好的工作,卻不能适用于新的情景。真正的模型測試是去測試它在從未接觸到的新場景中做預測的能力。
Fermi和von Neumann 對這四個參數的模型表示懷疑。我們的含有30個隐藏層識别MNIST數字的網絡有将近24000個參數!這已經很多了。我們的100個隐藏層的網絡有将近80000個參數,更有甚者,深度神經網絡有時可以包含上百萬甚至數十億個參數。那麼這個網絡的結果可信嗎?
我們簡化一下這個問題,構造一個網絡在新場景中效果不好的情況。我們使用30個隐藏層,它有23860個參數。但是我們不會使用所有的50 000個MNIST訓練圖檔。相反,我們隻用前1000張訓練圖檔。使用限定的集合和讓這個問題表現的更明顯。我們用和之前相似的方式去訓練,使用cross-entropy消耗函數,續寫率eta=0.5,最小集為10.不過,我們訓練400次,比之前多一些,以為我們沒有那樣多的訓練資料。讓我們使用network2去看看消耗函數的變化:
>>> import mnist_loader
>>> training_data, validation_data, test_data = \
... mnist_loader.load_data_wrapper()
>>> import network2
>>> net = network2.Network([784, 30, 10], cost=network2.CrossEntropyCost)
>>> net.large_weight_initializer()
>>> net.SGD(training_data[:1000], 400, 10, 0.5, evaluation_data=test_data,
... monitor_evaluation_accuracy=True, monitor_training_cost=True)
根據結果我們可以畫出消耗函數随着網絡學習的變化曲線:

看起來鼓舞人心,消耗光滑地下降,正是我們期望的。注意我們隻展示了訓練次數為200到339的結果。這使我們看到一個良好的後期學習結果趨勢,就像我們看到的一樣,證明了我們感興趣之處。
現在我們看一下識别精度如何随着測試資料的改變變化:
我已經放大了許多。在前200次訓練(未展示)中精度上升到了82%。然後學習逐漸下降。最後,在第280次時,識别精度停止了增長。後面隻能看到識别精度在280次的識别精度附近較小的随機波動。和簽名的圖檔對比,消耗一直在平滑地下降。單從消耗看起來我們的模型一直在變得“更好”。但測試精度表明這些改善都是假象。就像Fermi 否定的模型一樣,我們的模型在280次之後的訓練已經不再概括資料了。已經不是有效的學習了。我們稱這為網絡在280次後過拟合或過訓練。
你可能會懷疑問題在于我拿訓練資料的消耗和測試資料的識别精度作對比。話句話說,問題是因為我們做了一個牛頭不對馬嘴的比較。如果我們對比訓練資料和測試資料的消耗如何?這樣我們的比較度量就相似了?或者我們應該比較訓練資料和測試資料的識别精度?事實上,同樣的現象表明如何比較并不重要。細節總在變化。比如,我們看測試資料的消耗:
可以看到測試資料的消耗的改善直到15次,之後開始變壞,即使訓練資料的消耗在變好。這也表明我們的模型過拟合了。這是一個難題,我們應該把第15次還是第280次看作影響學習的過拟合點?從實踐來看,我們真正關心的是提升測試資料的識别精度,測試資料的消耗不足以代表識别精度。是以把第280次看作我們網絡的過拟合點更有意義。
從訓練資料的識别精度也可以看到過拟合: 精度一路飙升到了100%!就是說我們的網絡能夠正确地識别出所有1000張訓練圖檔!同時,我們的測試精度之後82.27%。是以我們的網絡學習到了訓練資料的特性,而不是識别資料的共性。就好像我們的網絡隻是記住了訓練資料,并沒有足夠地地了解數字以去歸納測試資料。
過拟合是神經網絡的一個主要問題。特别是在總是有大量權重和偏移量的現代網絡。為了能夠有效的訓練,我們需要一種方式去檢測并預防過拟合。我們希望有技術手段可以減少過拟合的影響。
上面的方式是一種有效的檢測過拟合的方法,在訓練網絡時跟蹤測試資料的識别精度。如果我們發現測試資料的識别精度不再提高,就應該停止訓練。當然,嚴格來說,這并不是是過拟合的标志。可能是測試資料和訓練資料的識别精度同時停止了增長。不過,采用這種方式仍能夠有效地預防過拟合。
實際應用中,我們會對這個方法做一些改變。記住我們加載的MNIST資料有三個集合:
>>> import mnist_loader
>>> training_data, validation_data, test_data = \
... mnist_loader.load_data_wrapper()
直至現在我們隻用到了訓練資料和測試資料,忽略了validation_data(确認資料)。确認資料集包含了10 000張數字圖檔,它們與50 000張訓練圖檔和10 000張測試圖檔都不同。我們使用确認資料集取代測試資料集去預防過拟合。