天天看點

center loss的完全了解以及實作

最近項目中需要 center loss 提升模型的效果,但是 center loss 的實作就有點不确定,看了很多的部落格,基本都是臆測,還是看源碼來的實在。

下面就大緻說下 center loss 的實作:

1、原理:

原理這塊大家可以參考别人的部落格,或者paper,這裡就簡單叙述下:讓得到全連接配接層向量距離對應類别中心的距離最小

2、問題

類别中心是動态變化的麼?如何進行變化?

(1)是每個epoch結束後使用所有的樣本重新聚類計算得到樣本中心麼?

(2)在每個batch内計算動态變化得到聚類中心

當然是第二種方式,第一種方式太過于直白,最大的問題就是更新的太滞後了,基本上業界沒有這樣用的。

那麼第二種方式該如何實作?每個batch内不一定包含所有的類别圖像,維護一個參數矩陣?如何初始化?如何得到類别中心點(聚類還是求均值?)?

3、具體的實作

确實需要一個參數矩陣來維護并更新我們得到的聚類中心,正常能想到的方式就是自定義一個layer,然後再layey種定義參數矩陣等等,最終加入模型進行訓練.

還有一種更為簡潔的方式就是使用 Embedding 層的方式進行輔助訓練,Embedding 層不僅僅可以實作一個次元的映射,而且最重要的是該層裡面也有參數,是一個可以被訓練的層,是以一切到這裡就可以結束了ÿ

繼續閱讀