天天看點

讓機器用人的方式識别圖像[codes]論文位址prototype layer模闆生成損失 R2 R 2 R_2projection

論文位址

This looks like that: deep learning for interpretable image recognition

prototype layer

prototype layer對應的公式如下:

gpj=maxz⃗ ∈patch(z)−log(||z⃗ −pj||22+ϵ) g p j = max z → ∈ p a t c h ( z ) − l o g ( | | z → − p j | | 2 2 + ϵ )

其中max層可以用現成的maxpool層來實作,下面重點讨論新的層實作

fz=−log(||z⃗ −pj||22+ϵ) f z = − l o g ( | | z → − p j | | 2 2 + ϵ )

這個層的輸入輸出的通道數可以不一緻,但是每個通道内的尺寸需要一緻. 為了後續讨論友善,約定以下符号

* in_data: (batchSize, inChNum, height, width), 輸入資料,即公式中的 z z

* weight: (outChNum, inChNum, 3, 3), filter權重, 此處假設采用3x3的filter

* out_data: (batchSize, outChNum, height, width), 輸出資料,迹公式中的f(z)f(z)

in_data (one sample)

z00 z 00 z01 z 01 z02 z 02
z10 z 10 z11 z 11 z12 z 12
z20 z 20 z21 z 21 z22 z 22
z30 z 30 z31 z 31 z32 z 32

weight (one sample and one output channel)

w00 w 00 w01 w 01
w10 w 10 w11 w 11
w20 w 20 w21 w 21

out_data (one sample and one output channnel)

a00 a 00 a01 a 01 a02 a 02
a10 a 10 a11 a 11 a12 a 12
a20 a 20 a21 a 21 a22 a 22

forward

按文中所述, 輸出 aij a i j 對應的輸入的patch的左上角是 zij z i j ,而不是常見的中心點.

fz=−log(||z⃗ −pj||22+ϵ) f z = − l o g ( | | z → − p j | | 2 2 + ϵ )

等價于

ai,j=−log∑m=02∑n=02D(wm,n,z(i+m),(j+n)) a i , j = − l o g ∑ m = 0 2 ∑ n = 0 2 D ( w m , n , z ( i + m ) , ( j + n ) )

其中

D(w,z)=(w−z)2 D ( w , z ) = ( w − z ) 2

PS: 考慮下為什麼是 w−z w − z 而不是 z−w z − w

backward

backward包括兩部分

* 權重w的梯度,用來調節w的值

* 輸入z的梯度,用來繼續反向傳播梯度資訊

假設整個網絡的損失函數 J J ,反向傳入的梯度資訊

∂J∂aij=δij∂J∂aij=δij

此處認為已知

grad_z

此時的目标函數用鍊式法則展開有

∂J∂zij=∑m=i−2i∑n=j−2jδmn∂amn∂zij ∂ J ∂ z i j = ∑ m = i − 2 i ∑ n = j − 2 j δ m n ∂ a m n ∂ z i j

其中

∂amn∂zij=2×w(i−m)(j−n)−zij∑2x=0∑2y=0D(wx,y,z(m+x),(n+y))=2×w(i−m)(j−n)−zijexp(−aij) ∂ a m n ∂ z i j = 2 × w ( i − m ) ( j − n ) − z i j ∑ x = 0 2 ∑ y = 0 2 D ( w x , y , z ( m + x ) , ( n + y ) ) = 2 × w ( i − m ) ( j − n ) − z i j e x p ( − a i j )

∂J∂zij=∑m=i−2i∑n=j−2j2δmnw(i−m)(j−n)−zijexp(−aij)=2exp(−aij)∑m=i−2i∑n=j−2jδmn(w(i−m)(j−n)−zij) ∂ J ∂ z i j = ∑ m = i − 2 i ∑ n = j − 2 j 2 δ m n w ( i − m ) ( j − n ) − z i j e x p ( − a i j ) = 2 e x p ( − a i j ) ∑ m = i − 2 i ∑ n = j − 2 j δ m n ( w ( i − m ) ( j − n ) − z i j )

grad_w

同理

∂J∂wij=∑m=0height∑n=0widthδmn∂amn∂wij ∂ J ∂ w i j = ∑ m = 0 h e i g h t ∑ n = 0 w i d t h δ m n ∂ a m n ∂ w i j

其中

∂amn∂wij=(−2)×wij−z(m+i)(n+j)∑2x=0∑2y=0D(wx,y,z(m+x),(n+y))=(−2)×wij−z(m+i)(n+j)exp(−aij) ∂ a m n ∂ w i j = ( − 2 ) × w i j − z ( m + i ) ( n + j ) ∑ x = 0 2 ∑ y = 0 2 D ( w x , y , z ( m + x ) , ( n + y ) ) = ( − 2 ) × w i j − z ( m + i ) ( n + j ) e x p ( − a i j )

∂J∂wij=∑m=0height∑n=0width(−2)δmnwij−z(m+i)(n+j)exp(−aij)=−2exp(−aij)∑m=0height∑n=0widthδmn(wij−z(m+i)(n+j)) ∂ J ∂ w i j = ∑ m = 0 h e i g h t ∑ n = 0 w i d t h ( − 2 ) δ m n w i j − z ( m + i ) ( n + j ) e x p ( − a i j ) = − 2 e x p ( − a i j ) ∑ m = 0 h e i g h t ∑ n = 0 w i d t h δ m n ( w i j − z ( m + i ) ( n + j ) )

模闆生成損失 R2 R 2

模闆生成損失 R2 R 2 可以通過maxpool的輸出,取 exp(−aij) e x p ( − a i j ) 獲得, 直接權重到反向傳播輸入的梯度資訊即可

projection

每5個epoch,需要把prototype的weight設定成訓練集中距離最近的patch.

繼續閱讀