天天看點

深度學習筆記4:深度神經網絡的正則化

恍恍惚惚,又20天沒寫了。今天筆者要寫的是關于機器學習和深度學習中的一項關鍵技術:正則化。相信在機器學習領域摸爬滾打多年的你一定知道正則化是防止模型過拟合的核心技術之一,關于欠拟合和過拟合的問題,本篇筆者就不再展開來說,筆者年初就在一篇文章中詳細通俗地闡述了過拟合的相關問題。想要看的朋友猛戳 談談過拟合 總的來說,監督機器學習的核心原理莫過于如下公式:

深度學習筆記4:深度神經網絡的正則化

該公式可謂是機器學習中最核心最關鍵最能概述監督學習的核心思想的公式了:所有的有監督機器學習,無非就是正則化參數的同時最小化經驗誤差函數。最小化經驗誤差是為了極大程度的拟合訓練資料,正則化參數是為了防止過分的拟合訓練資料。你看,多麼簡約數學哲學。正如之前所說,監督機器學習是為了讓我們建立的模型能夠發現資料中普遍的一般的規律,這個普遍的一般的規律無論對于訓練集還是未知的測試集,都具有較好的拟合性能。通俗點舉例就是,考試能力很強,應用能力很差,或者是模拟考很強,聯考卻一般。

先不扯遠了,繼續回到公式。第一項經驗誤差函數在機器學習中無疑地位重要,但它不是筆者今天要講的,今天要講的是公式的第二項:正則化項。第二項中 λ 為正則化系數,通常是大于 0 的,是一種調整經驗誤差項和正則化項之間關系的系數。λ = 0 時相當于該公式沒有正則化項,模型全力讨好第一項,将經驗誤差進行最小化,往往這也是最容易發生過拟合的時候。随着 λ 逐漸增大,正則化項在模型選擇中的話語權越來越高,對模型的複雜性的懲罰也越來越厲害。是以,在實際的訓練過程中,λ 作為一種超參數很大程度上決定了模型生死。

L1 和 L2 範數

系數 λ 說完了,然後就是正則化項,正則化項形式有很多,但常見的也就是 L1 和 L2 正則化。下面筆者就帶大家好好拾掇拾掇這些個 L1 L2。

在說常見的 L1 和 L2 之前,先來看一下 L0 正則化。L0 正則化也就是 L0 範數,即矩陣中所有非 0 元素的個數。如何我們在正則化過程中選擇了 L0 範數,那該如何了解這個 L0 呢?其實非常簡單,L0 範數就是希望要正則化的參數矩陣 W 大多數元素都為 0。如此簡單粗暴,讓參數矩陣 W 大多數元素為 0 就是實作稀疏而已。說到這裡,權且打住,想必同樣在機器學習領域摸爬滾打的你一定想問,據我所知稀疏性不通常都是用 L1 來實作的嗎?這裡個中緣由筆者不去細講了,簡單說結論:在機器學習領域,L0 和 L1 都可以實作矩陣的稀疏性,但在實踐中,L1 要比 L0 具備更好的泛化求解特性而廣受青睐。先說了 L1,但還沒解釋 L1 範數是什麼,L1 範數就是矩陣中各元素絕對值之和,正如前述所言,L1 範數通常用于實作參數矩陣的稀疏性。至于為啥要稀疏,稀疏有什麼用,通常是為了特征選擇和易于解釋方面的考慮。

深度學習筆記4:深度神經網絡的正則化
再來看 L2 範數。相較于 L0 和 L1,其實 L2 才是正則化中的天選之子。在各種防止過拟合和正則化處理過程中,L2 正則化可謂風頭無二。L2 範數是指矩陣中各元素的平方和後的求根結果。采用 L2 範數進行正則化的原理在于最小化參數矩陣的每個元素,使其無限接近于 0 但又不像 L1 那樣等于 0,也許你又會問了,為什麼參數矩陣中每個元素變得很小就能防止過拟合?這裡我們就拿深度神經網絡來舉例說明吧。在 L2 正則化中,如何正則化系數變得比較大,參數矩陣 W 中的每個元素都在變小,線性計算的和 Z 也會變小,激活函數在此時相對呈線性狀态,這樣就大大簡化了深度神經網絡的複雜性,因而可以防止過拟合。
深度學習筆記4:深度神經網絡的正則化
至于 L1 和 L2,江湖上還有一些混名,L1 就是江湖上著名的 lasso,L2 呢則是嶺回歸。二者都是對回歸損失函數加一個限制形式,lasso 加的是 L1 範數,嶺回歸加的是 L2 範數。可以從幾何直覺上看看二者的差別。
深度學習筆記4:深度神經網絡的正則化
L1 和 L2 的下降速度
深度學習筆記4:深度神經網絡的正則化
L1 和 L2 的模型空間 神經網絡的正則化 說了半天的範數,下面我們就來看看在神經網絡中如何進行正則化操作防止過拟合。為了跟前面筆記保持一緻,我們在神經網絡訓練過程中繼續采用交叉熵損失函數:
深度學習筆記4:深度神經網絡的正則化

加了正則化項之後,損失函數形式如上所示,損失函數變了,反向傳播的梯度計算也就變了,相應的反向傳播也需要重新定義函數。

帶正則化項的損失函數的定義:

def compute_cost_with_regularization(A3, Y, parameters, lambd): """

Implement the cost function with L2 regularization. See formula (2) above.

Arguments:

A3 -- post-activation, output of forward propagation, of shape (output size, number of examples)

Y -- "true" labels vector, of shape (output size, number of examples)

cost - value of the regularized loss function (formula (2))

parameters -- python dictionary containing parameters of the model

Returns:

"""

m = Y.shape[1]

W1 = parameters["W1"]

W2 = parameters["W2"]

W3 = parameters["W3"]

cross_entropy_cost = compute_cost(A3, Y) # This gives you the cross-entropy part of the cost

L2_regularization_cost = 1/m * lambd/2 * (np.sum(np.square(W1))+np.sum(np.square(W2))+np.sum(np.square(W3)))

cost = cross_entropy_cost + L2_regularization_cost

return cost

反向傳播的函數定義:

def backward_propagation_with_regularization(X, Y, cache, lambd): """

Implements the backward propagation of our baseline model to which we added an L2 regularization.

Arguments:

X -- input dataset, of shape (input size, number of examples)

cache -- cache output from forward_propagation()

Y -- "true" labels vector, of shape (output size, number of examples)

lambd -- regularization hyperparameter, scalar

gradients -- A dictionary with the gradients with respect to each parameter, activation and pre-activation variables

m = X.shape[1]

(Z1, A1, W1, b1, Z2, A2, W2, b2, Z3, A3, W3, b3) = cache

dZ3 = A3 - Y

dW3 = 1./m * np.dot(dZ3, A2.T) + lambd/m * W3

db3 = 1./m * np.sum(dZ3, axis=1, keepdims = True)

dA2 = np.dot(W3.T, dZ3)

dZ2 = np.multiply(dA2, np.int64(A2 > 0))

dW2 = 1./m * np.dot(dZ2, A1.T) + lambd/m * W2

db2 = 1./m * np.sum(dZ2, axis=1, keepdims = True)

dA1 = np.dot(W2.T, dZ2)

dZ1 = np.multiply(dA1, np.int64(A1 > 0))

dW1 = 1./m * np.dot(dZ1, X.T) + lambd/m * W1

db1 = 1./m * np.sum(dZ1, axis=1, keepdims = True)

gradients = {"dZ3": dZ3, "dW3": dW3, "db3": db3,"dA2": dA2, "dZ2": dZ2, "dW2": dW2, "db2": db2, "dA1": dA1,

"dZ1": dZ1, "dW1": dW1, "db1": db1}

return gradients

在執行個體中,加了正則化項和沒加正則化項的模型分類結果可如圖所見:

深度學習筆記4:深度神經網絡的正則化

未經正則化處理的分類模型結果

深度學習筆記4:深度神經網絡的正則化

加上正則化後的模型分類結果

效果顯而易見,加了正則化之後,神經網絡的過拟合情況得到極大的緩解。

原文釋出時間為:2018-09-1

本文作者:louwill

本文來自雲栖社群合作夥伴“

Python愛好者社群

”,了解相關資訊可以關注“

”。