問題:教師的logits進行訓練可能出現incorrect和overly uncertain的監督
解決: (1) Logits Adjustment(LA)
(2) Dynamic Temperature Distillation(DTD)
LA:針對錯誤判斷的訓練樣本,交換GT标簽和誤判标簽的logits值
DTD:一些uncertain soft target是因為過高的溫度值,是以可采用動态溫度計算soft target,
該溫度在訓練期間自适應更新----->學生模型在訓練中能夠獲得更多discriminative information,可視為線上的hard examples mining過程(OHEM:根據損失選擇樣本,重新送入網絡)。
文章提出的方法:
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsICM38FdsYkRGZkRG9lcvx2bjxiNx8VZ6l2cs0TPRRmbWNjW1EjMMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2X0hXZ0xCMx81dvRWYoNHLrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdssmch1mclRXY39CXldWYtlWPzNXZj9mcw1ycz9WL49zZuBnL4QDOzUDN1QTM1ADOwEjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
3.1回顧知識蒸餾:
公式(1):之前KD工作中為了将網絡的softmax輸出機率分布更加soft,便向softmax函數引入了溫度參數
ti 和si分别為教師模型和學生模型的logits
q和p分别為教師模型和學生模型的softmax機率分布。
公式(2):是蒸餾過程中的整體損失構成=學生損失+蒸餾損失;
學生損失是學生模型基于ground truth label進行訓練學習得到的損失,緩解學習到的 不正确的logits資訊;但這種緩解隻是輕微的,仍存在遺傳錯誤的問題。
文中驗證:統計在ResNet50指導下ResNet18在CIFAR-10和CIFAR-100的遺傳錯誤率分别有57%和42%。
3.2 Genetic errors and logits adjustment
新概念Genetic errors: 學生模型的錯誤預測與教師的錯誤預測一緻,稱為遺傳性錯誤。
提出LA,試圖fix教師的預測,對教師的softened logits執行函數A(·),修改後的損失仍用交叉熵:
A(·)的3個特點:
(1)fix錯誤的qt,而對正确的不做任何事,為保證訓練穩定性,所作的修改盡可能小一些。
(2)在學生訓練期間,從教師網絡得到的參數qt不可變,是以優化對象可用交叉熵表示,而不是用KL。
(3)交叉熵計算不用y,因為A(qt)是完全正确的。
3.2.1 Why not LSR?
Label Smooth Regularization (LSR)屬于LA最簡單的實作方式,但是有所限制。
LSR label:
類别數量k 樣本x 脈沖信号δ(·)
LSR丢棄教師預測的非true類别的機率,但這在KD中被證明是有幫助的。
是以,另一種簡單實作被提出:Probability Shift(PS). 機率轉移
思想:交換真實值标簽值(理論上最大值)與預測類别值(預測的最大值),以保證最大置信度落在真實值标簽。
Fig2. PD on 誤判樣本的soft target ,The sample is from CIFAR-100 training data, whose ground truth label is leopard but ResNet-50 teacher’s prediction is rabbit.The value of ”leopard” is still large, which indicates that the teacher does not go ridiculous. 轉換操作就是交換兩個類的值,得到一個leopard的最大預測值,rabbit的第二大預測值。
與LSR相比,PS保留了涉及微小機率的類别間差異,LSR則丢棄了大部分。不正确的預測類别往往與真實類别有一些相似的特征。也就是說,不正确的預測類别可能比其他類别包含更多資訊。該方法還保留了軟目标的數字分布,這對穩定訓練過程是有幫助的。
動态溫度蒸餾:
[30,31,44]研究表明學生可以從監督的不确定性中受益,但教師的過度不确定的預測也可能會影響性能。
下圖:蒸餾softmax的可視化
圖中可以看到随着溫度的升高,各個類之間的機率差異變小,而真實值是leopard,另外兩種kangroo和rabbit在訓練中就是幹擾項,是以為更好區分(擴大類間相似度),應當選擇更小些的溫度值。(但之前有提到過較高的溫度值可以讓softmax輸出更加soft,但這裡用較高的溫度會讓非真實類稱為訓練幹擾項,是以采用動态溫度DTD的思想)
DTD描述:這裡用的KL散度,而不是cross entropy,因為ptx是變化的,不是一個常量。
t0和β代表 基礎溫度和偏差,wx代表樣本x的批量歸一化權重,描述混淆的程度。當樣本x有些混淆且教師預測值不确定時,wx會增加。如此一來tx<t0,soft targets更加有區分性。
======》混亂的樣本會有更大的weights,更低的溫度值。---樣本之間就更加有區分度。
======》DTD更加關注那些confusing examples,就可以視為是一種hard examples mining.
文中提出兩種方法計算權重wx,:
一種是FLSW計算sample-wise 權重; 另一種是依據學生預測的最大輸出計算wx,稱為Confidence Weighted by Student Max(CWSM)
Focal loss style weights:
原來的focal loss:
p為一個樣本的分類分數。
本文方法中:學習難度可通過學生的logit v和教師的logit t之間的相似性來衡量。為簡便,将r設定為一個常量,得
(wx代表迷惑程度即難分類程度)
v·t∈[-1,1]是兩個分布的内積。當學生預測與教師的預測相差甚遠時,wx就會變大。
(回顧:A·B 内積計算是由第一個矩陣的每一行乘以第二個矩陣的每一列得到的)
Confidence Weighted by Student Max
根據學生歸一化的logits的最大值給樣本權重,在一定程度上可以反應樣本的學習情況。
學生模型通常對confusing 樣本有着不确定的預測,其logits的最大值也相應小一些,這裡計算wx公式描述為:
其中學生的logit v 應該是normalized的,vmax被視為代表學生對樣本的置信度 。低置信度的樣本有更高的權重,這些樣本的梯度在蒸餾過程中也貢獻的更多。
Compound loss function and algorithm複式損失函數和算法
結合LA和DTD,整體損失:
與公式(3)類似,但不同在于(10)采用sample-wise溫度來soften logits.
監督張量A(qtx)會随着學習情況不斷變化。
此外這裡沒必要使用真實值交叉熵,因為A(qtx)總是正确的。
實驗:
資料集:CIFAR-10, CIFAR-100, Tiny ImageNet
方法比較:标準蒸餾 (KD),注意力機制 (AT),神經元選擇性遷移 (NST)
标準蒸餾KD:
公式(2)中α=0.7,友善起見,用KL散度實作兩個分布之間的交叉熵。
AT:KD中引入注意力機制
NST:
注:MMD(Maximum Mean Discrepancy)
SP:Similarity Preserving Distillation
從特征相似性的角度提出了一種新的蒸餾損失,引導學生模仿樣本間的相似性,而不是教師對空間的邏輯.
其中b為batch size,||·||F是矩陣的Frobenius範數,矩陣中的元素的平方和再開方。對于向量而言就是L2距離。Gt和Gs分别為教師和學生模型certain layer的相似性矩陣。