論文位址
This looks like that: deep learning for interpretable image recognition
prototype layer
prototype layer對應的公式如下:
gpj=maxz⃗ ∈patch(z)−log(||z⃗ −pj||22+ϵ) g p j = max z → ∈ p a t c h ( z ) − l o g ( | | z → − p j | | 2 2 + ϵ )
其中max層可以用現成的maxpool層來實作,下面重點讨論新的層實作
fz=−log(||z⃗ −pj||22+ϵ) f z = − l o g ( | | z → − p j | | 2 2 + ϵ )
這個層的輸入輸出的通道數可以不一緻,但是每個通道内的尺寸需要一緻. 為了後續讨論友善,約定以下符号
* in_data: (batchSize, inChNum, height, width), 輸入資料,即公式中的 z z
* weight: (outChNum, inChNum, 3, 3), filter權重, 此處假設采用3x3的filter
* out_data: (batchSize, outChNum, height, width), 輸出資料,迹公式中的f(z)f(z)
in_data (one sample)
z00 z 00 | z01 z 01 | z02 z 02 |
z10 z 10 | z11 z 11 | z12 z 12 |
z20 z 20 | z21 z 21 | z22 z 22 |
z30 z 30 | z31 z 31 | z32 z 32 |
weight (one sample and one output channel)
w00 w 00 | w01 w 01 |
w10 w 10 | w11 w 11 |
w20 w 20 | w21 w 21 |
out_data (one sample and one output channnel)
a00 a 00 | a01 a 01 | a02 a 02 |
a10 a 10 | a11 a 11 | a12 a 12 |
a20 a 20 | a21 a 21 | a22 a 22 |
forward
按文中所述, 輸出 aij a i j 對應的輸入的patch的左上角是 zij z i j ,而不是常見的中心點.
fz=−log(||z⃗ −pj||22+ϵ) f z = − l o g ( | | z → − p j | | 2 2 + ϵ )
等價于
ai,j=−log∑m=02∑n=02D(wm,n,z(i+m),(j+n)) a i , j = − l o g ∑ m = 0 2 ∑ n = 0 2 D ( w m , n , z ( i + m ) , ( j + n ) )
其中
D(w,z)=(w−z)2 D ( w , z ) = ( w − z ) 2
PS: 考慮下為什麼是 w−z w − z 而不是 z−w z − w
backward
backward包括兩部分
* 權重w的梯度,用來調節w的值
* 輸入z的梯度,用來繼續反向傳播梯度資訊
假設整個網絡的損失函數 J J ,反向傳入的梯度資訊
∂J∂aij=δij∂J∂aij=δij
此處認為已知
grad_z
此時的目标函數用鍊式法則展開有
∂J∂zij=∑m=i−2i∑n=j−2jδmn∂amn∂zij ∂ J ∂ z i j = ∑ m = i − 2 i ∑ n = j − 2 j δ m n ∂ a m n ∂ z i j
其中
∂amn∂zij=2×w(i−m)(j−n)−zij∑2x=0∑2y=0D(wx,y,z(m+x),(n+y))=2×w(i−m)(j−n)−zijexp(−aij) ∂ a m n ∂ z i j = 2 × w ( i − m ) ( j − n ) − z i j ∑ x = 0 2 ∑ y = 0 2 D ( w x , y , z ( m + x ) , ( n + y ) ) = 2 × w ( i − m ) ( j − n ) − z i j e x p ( − a i j )
則
∂J∂zij=∑m=i−2i∑n=j−2j2δmnw(i−m)(j−n)−zijexp(−aij)=2exp(−aij)∑m=i−2i∑n=j−2jδmn(w(i−m)(j−n)−zij) ∂ J ∂ z i j = ∑ m = i − 2 i ∑ n = j − 2 j 2 δ m n w ( i − m ) ( j − n ) − z i j e x p ( − a i j ) = 2 e x p ( − a i j ) ∑ m = i − 2 i ∑ n = j − 2 j δ m n ( w ( i − m ) ( j − n ) − z i j )
grad_w
同理
∂J∂wij=∑m=0height∑n=0widthδmn∂amn∂wij ∂ J ∂ w i j = ∑ m = 0 h e i g h t ∑ n = 0 w i d t h δ m n ∂ a m n ∂ w i j
其中
∂amn∂wij=(−2)×wij−z(m+i)(n+j)∑2x=0∑2y=0D(wx,y,z(m+x),(n+y))=(−2)×wij−z(m+i)(n+j)exp(−aij) ∂ a m n ∂ w i j = ( − 2 ) × w i j − z ( m + i ) ( n + j ) ∑ x = 0 2 ∑ y = 0 2 D ( w x , y , z ( m + x ) , ( n + y ) ) = ( − 2 ) × w i j − z ( m + i ) ( n + j ) e x p ( − a i j )
則
∂J∂wij=∑m=0height∑n=0width(−2)δmnwij−z(m+i)(n+j)exp(−aij)=−2exp(−aij)∑m=0height∑n=0widthδmn(wij−z(m+i)(n+j)) ∂ J ∂ w i j = ∑ m = 0 h e i g h t ∑ n = 0 w i d t h ( − 2 ) δ m n w i j − z ( m + i ) ( n + j ) e x p ( − a i j ) = − 2 e x p ( − a i j ) ∑ m = 0 h e i g h t ∑ n = 0 w i d t h δ m n ( w i j − z ( m + i ) ( n + j ) )
模闆生成損失 R2 R 2
模闆生成損失 R2 R 2 可以通過maxpool的輸出,取 exp(−aij) e x p ( − a i j ) 獲得, 直接權重到反向傳播輸入的梯度資訊即可
projection
每5個epoch,需要把prototype的weight設定成訓練集中距離最近的patch.