天天看點

反向傳播公式推導

參考:《神經網絡與深度學習》

https://legacy.gitbook.com/book/xhhjin/neural-networks-and-deep-learning-zh

該筆記主要是反向傳播公式的推導,了解反向傳播的話建議看其他部落格中更加具體的例子或者吳恩達老師反向傳播介紹的視訊(有具體數字的例子),主要有4個公式的推導:

(BP1) δ j L = ∂ C ∂ z j L = ∂ C ∂ a j L ⋅ σ ′ ( z j L ) \delta_j^L=\frac{\partial C}{\partial z^L_j}=\frac{\partial C}{\partial a^L_j} \cdot \sigma'(z^L_j) \tag{BP1} δjL​=∂zjL​∂C​=∂ajL​∂C​⋅σ′(zjL​)(BP1)

(BP2) δ l = W l + 1 T ⋅ δ l + 1 ⊙ σ ′ ( z l ) \delta ^ {l} = {W^{l+1}} ^\mathsf{T} \cdot \delta^{l+1} \odot \sigma'(z^l) \tag{BP2} δl=Wl+1T⋅δl+1⊙σ′(zl)(BP2)

(BP3) ∂ C ∂ w j k l = ∂ C ∂ z j l ⋅ ∂ z j l ∂ w j k l = δ j l ⋅ a k l − 1 \frac{\partial C}{\partial w^{l}_{jk}} = \frac{\partial C}{\partial z^{l}_{j}} \cdot \frac{\partial z^{l}_{j}}{ \partial w^{l}_{jk}} = \delta_j^l \cdot a_k^{l-1} \tag{BP3} ∂wjkl​∂C​=∂zjl​∂C​⋅∂wjkl​∂zjl​​=δjl​⋅akl−1​(BP3)

(BP4) ∂ C ∂ b j l = ∂ C ∂ z j l ⋅ ∂ z j l ∂ b j l = δ j l \frac{\partial C}{\partial b^{l}_{j}} = \frac{\partial C}{\partial z^{l}_{j}} \cdot \frac{\partial z^{l}_{j}}{ \partial b^{l}_{j}} = \delta_j^l \tag{BP4} ∂bjl​∂C​=∂zjl​∂C​⋅∂bjl​∂zjl​​=δjl​(BP4)

  1. 公式BP1的推導:

    先推導神經網絡最後一層L的公式,假設一個二分類的神經網絡在最後一層如下圖所示:

    反向傳播公式推導

    C表示損失(或者叫loss)

    從公式中可以看出: ∂ C ∂ a j L \frac{\partial C}{\partial a^L_j} ∂ajL​∂C​可直接求出來,接下來是推導 ∂ C ∂ z j L \frac{\partial C}{\partial z^L_j} ∂zjL​∂C​, 定義 δ j L = ∂ C ∂ z j L \delta_j^L=\frac{\partial C}{\partial z^L_j} δjL​=∂zjL​∂C​

(1) δ j L = ∂ C ∂ z j L = ∂ C ∂ a j L ⋅ σ ′ ( z j L ) \delta_j^L=\frac{\partial C}{\partial z^L_j}=\frac{\partial C}{\partial a^L_j} \cdot \sigma'(z^L_j) \tag{1} δjL​=∂zjL​∂C​=∂ajL​∂C​⋅σ′(zjL​)(1)

寫成矩陣的形式:

(2) δ L = ∂ C ∂ a L ⊙ σ ′ ( z L ) = Δ a C ⊙ σ ′ ( z L ) \delta^L = \frac{\partial C}{\partial a^L} \odot \sigma'(z^L) = \Delta_aC \odot \sigma'(z^L) \tag{2} δL=∂aL∂C​⊙σ′(zL)=Δa​C⊙σ′(zL)(2)

公式中 ⊙ \odot ⊙表示Hadamard積。

  1. 公式BP2的推導

    L層已經能計算了,接下來是推導前一層的情況,為了表示友善,直接考慮從l+1層到l層的情況,神經網絡以及變量命名示意圖如下圖所示:

    反向傳播公式推導

    (這裡 w j k l + 1 w_{jk}^{l+1} wjkl+1​表示從k神經元到j神經元,了解起來有些拗口,主要是為了友善正向傳播中W矩陣的表示,也可以按照自己喜好來進行變量命名,反向傳播公式推導麻煩很大程度上是因為變量指令麻煩。)

    先得弄清楚前向傳播中的一個公式:

    (3) z j l + 1 = ∑ k = 1 k w j k l + 1 ⋅ a k l + b j l + 1 z_j^{l+1}= \sum_{k=1}^k w_{jk}^{l+1} \cdot a_k^l+b_j^{l+1} \tag{3} zjl+1​=k=1∑k​wjkl+1​⋅akl​+bjl+1​(3)

寫成矩陣形式為:

(4) Z l + 1 = W l + 1 ⋅ A l + B l + 1 Z^{l+1}= W^{l+1} \cdot A^l+B^{l+1} \tag{4} Zl+1=Wl+1⋅Al+Bl+1(4)

在公式(3)中,現在 ∂ C ∂ z j l + 1 \frac{\partial C}{\partial z^{l+1}_j} ∂zjl+1​∂C​已知,需要求 ∂ C ∂ a k l \frac{\partial C}{\partial a^{l}_k} ∂akl​∂C​,自然想到鍊式求導法則:

(5) ∂ C ∂ a k l = ∑ j = 1 j ∂ C ∂ z j l + 1 ⋅ ∂ z j l + 1 ∂ a k l \frac{\partial C}{\partial a^{l}_k}= \sum_{j=1}^j \frac{\partial C}{\partial z^{l+1}_j} \cdot \frac{\partial z^{l+1}_j}{\partial a^{l}_k} \tag{5} ∂akl​∂C​=j=1∑j​∂zjl+1​∂C​⋅∂akl​∂zjl+1​​(5)

也可以從實體意義去了解這個公式, ∂ C ∂ a k l \frac{\partial C}{\partial a^{l}_k} ∂akl​∂C​表示 a k l a^{l}_k akl​的變化對C的影響大小, a k l a^{l}_k akl​可以從 z j l + 1 ( j = 1... J ) z^{l+1}_j(j=1...J) zjl+1​(j=1...J)來影響C的大小,是以公式(5)中需要有累加。

公式(5)繼續化簡,這裡需要參考公式(3):

(6) ∂ C ∂ a k l = ∑ j = 1 j δ j l + 1 ⋅ ∂ z j l + 1 ∂ a k l = ∑ j = 1 j δ j l + 1 ⋅ w j k l + 1 \frac{\partial C}{\partial a^{l}_k} = \sum_{j=1}^j \delta_j^{l+1} \cdot \frac{\partial z^{l+1}_j}{\partial a^{l}_k} = \sum_{j=1}^j \delta_j^{l+1} \cdot w_{jk}^{l+1} \tag{6} ∂akl​∂C​=j=1∑j​δjl+1​⋅∂akl​∂zjl+1​​=j=1∑j​δjl+1​⋅wjkl+1​(6)

(7) ∂ C ∂ z k l = ( ∑ j = 1 j δ j l + 1 ⋅ w j k l + 1 ) ⋅ σ ′ ( z k l ) \frac{\partial C}{\partial z^{l}_k} = (\sum_{j=1}^j \delta_j^{l+1} \cdot w_{jk}^{l+1}) \cdot \sigma'(z_k^l) \tag{7} ∂zkl​∂C​=(j=1∑j​δjl+1​⋅wjkl+1​)⋅σ′(zkl​)(7)

寫成矩陣形式有:

(8) [ ∂ C ∂ z 1 l ∂ C ∂ z 2 l ∂ C ∂ z k l ] = [ w 11 l + 1 w 21 l + 1 w j 1 l + 1 w 12 l + 1 w 22 l + 1 w j 2 l + 1 w 1 k l + 1 w 2 k l + 1 w j k l + 1 ] ⋅ [ δ 1 l + 1 δ 2 l + 1 δ j l + 1 ] ⊙ [ σ ′ ( z 1 l ) σ ′ ( z 2 l ) σ ′ ( z k l ) ] \begin{bmatrix} \frac{\partial C}{\partial z^{l}_1} \\ \frac{\partial C}{\partial z^{l}_2} \\ \frac{\partial C}{\partial z^{l}_k} \end{bmatrix} = \begin{bmatrix} w^{l+1}_{11} & w^{l+1}_{21} & w^{l+1}_{j1} \\ w^{l+1}_{12} & w^{l+1}_{22} & w^{l+1}_{j2} \\ w^{l+1}_{1k} & w^{l+1}_{2k} & w^{l+1}_{jk} \\ \end{bmatrix} \cdot \begin{bmatrix} \delta_1^{l+1} \\ \delta_2^{l+1} \\ \delta_j^{l+1} \end{bmatrix} \odot \begin{bmatrix} \sigma'(z_1^l) \\ \sigma'(z_2^l) \\ \sigma'(z_k^l) \end{bmatrix} \tag{8} ⎣⎢⎡​∂z1l​∂C​∂z2l​∂C​∂zkl​∂C​​⎦⎥⎤​=⎣⎡​w11l+1​w12l+1​w1kl+1​​w21l+1​w22l+1​w2kl+1​​wj1l+1​wj2l+1​wjkl+1​​⎦⎤​⋅⎣⎡​δ1l+1​δ2l+1​δjl+1​​⎦⎤​⊙⎣⎡​σ′(z1l​)σ′(z2l​)σ′(zkl​)​⎦⎤​(8)

(9) [ ∂ C ∂ z 1 l ∂ C ∂ z 2 l ∂ C ∂ z k l ] = [ w 11 l + 1 w 12 l + 1 w 1 k l + 1 w 21 l + 1 w 22 l + 1 w 2 k l + 1 w j 1 l + 1 w j 2 l + 1 w j k l + 1 ] T ⋅ [ δ 1 l + 1 δ 2 l + 1 δ j l + 1 ] ⊙ [ σ ′ ( z 1 l ) σ ′ ( z 2 l ) σ ′ ( z k l ) ] \begin{bmatrix} \frac{\partial C}{\partial z^{l}_1} \\ \frac{\partial C}{\partial z^{l}_2} \\ \frac{\partial C}{\partial z^{l}_k} \end{bmatrix} = \begin{bmatrix} w^{l+1}_{11} & w^{l+1}_{12} & w^{l+1}_{1k} \\ w^{l+1}_{21} & w^{l+1}_{22} & w^{l+1}_{2k} \\ w^{l+1}_{j1} & w^{l+1}_{j2} & w^{l+1}_{jk} \\ \end{bmatrix} ^\mathsf{T} \cdot \begin{bmatrix} \delta_1^{l+1} \\ \delta_2^{l+1} \\ \delta_j^{l+1} \end{bmatrix} \odot \begin{bmatrix} \sigma'(z_1^l) \\ \sigma'(z_2^l) \\ \sigma'(z_k^l) \end{bmatrix} \tag{9} ⎣⎢⎡​∂z1l​∂C​∂z2l​∂C​∂zkl​∂C​​⎦⎥⎤​=⎣⎡​w11l+1​w21l+1​wj1l+1​​w12l+1​w22l+1​wj2l+1​​w1kl+1​w2kl+1​wjkl+1​​⎦⎤​T⋅⎣⎡​δ1l+1​δ2l+1​δjl+1​​⎦⎤​⊙⎣⎡​σ′(z1l​)σ′(z2l​)σ′(zkl​)​⎦⎤​(9)

(10) δ l = W l + 1 T ⋅ δ l + 1 ⊙ σ ′ ( z l ) \delta ^ {l} = {W^{l+1}} ^\mathsf{T} \cdot \delta^{l+1} \odot \sigma'(z^l) \tag{10} δl=Wl+1T⋅δl+1⊙σ′(zl)(10)

3.公式BP3的推導:

然後推導 ∂ C ∂ w j k l \frac{\partial C}{\partial w^{l}_{jk}} ∂wjkl​∂C​ 和 ∂ C ∂ b l \frac{\partial C}{\partial b^{l}} ∂bl∂C​, 這也是神經網絡中實際參數更新需要計算的參數,先推導 ∂ C ∂ w j k l \frac{\partial C}{\partial w^{l}_{jk}} ∂wjkl​∂C​。

根據公式(3)可知:

(11) ∂ C ∂ w j k l = ∂ C ∂ z j l ⋅ ∂ z j l ∂ w j k l = δ j l ⋅ a k l − 1 \frac{\partial C}{\partial w^{l}_{jk}} = \frac{\partial C}{\partial z^{l}_{j}} \cdot \frac{\partial z^{l}_{j}}{ \partial w^{l}_{jk}} = \delta_j^l \cdot a_k^{l-1} \tag{11} ∂wjkl​∂C​=∂zjl​∂C​⋅∂wjkl​∂zjl​​=δjl​⋅akl−1​(11)

4.公式BP4的推導:

然後推導 ∂ C ∂ b l \frac{\partial C}{\partial b^{l}} ∂bl∂C​

(12) ∂ C ∂ b j l = ∂ C ∂ z j l ⋅ ∂ z j l ∂ b j l = δ j l \frac{\partial C}{\partial b^{l}_{j}} = \frac{\partial C}{\partial z^{l}_{j}} \cdot \frac{\partial z^{l}_{j}}{ \partial b^{l}_{j}} = \delta_j^l \tag{12} ∂bjl​∂C​=∂zjl​∂C​⋅∂bjl​∂zjl​​=δjl​(12)

繼續閱讀