天天看點

Preparing lessons: Improve knowledge distillation with Better supervision

問題:教師的logits進行訓練可能出現incorrect和overly uncertain的監督

解決: (1) Logits Adjustment(LA)

            (2) Dynamic Temperature Distillation(DTD)

LA:針對錯誤判斷的訓練樣本,交換GT标簽和誤判标簽的logits值

DTD:一些uncertain soft target是因為過高的溫度值,是以可采用動态溫度計算soft target,

該溫度在訓練期間自适應更新----->學生模型在訓練中能夠獲得更多discriminative information,可視為線上的hard examples mining過程(OHEM:根據損失選擇樣本,重新送入網絡)。

文章提出的方法:

Preparing lessons: Improve knowledge distillation with Better supervision

3.1回顧知識蒸餾:

Preparing lessons: Improve knowledge distillation with Better supervision

 公式(1):之前KD工作中為了将網絡的softmax輸出機率分布更加soft,便向softmax函數引入了溫度參數

Preparing lessons: Improve knowledge distillation with Better supervision

 ti 和si分别為教師模型和學生模型的logits

q和p分别為教師模型和學生模型的softmax機率分布。

公式(2):是蒸餾過程中的整體損失構成=學生損失+蒸餾損失;

學生損失是學生模型基于ground truth label進行訓練學習得到的損失,緩解學習到的 不正确的logits資訊;但這種緩解隻是輕微的,仍存在遺傳錯誤的問題。

文中驗證:統計在ResNet50指導下ResNet18在CIFAR-10和CIFAR-100的遺傳錯誤率分别有57%和42%。

 3.2 Genetic errors and logits adjustment

新概念Genetic errors: 學生模型的錯誤預測與教師的錯誤預測一緻,稱為遺傳性錯誤。

提出LA,試圖fix教師的預測,對教師的softened logits執行函數A(·),修改後的損失仍用交叉熵:

Preparing lessons: Improve knowledge distillation with Better supervision

 A(·)的3個特點:

(1)fix錯誤的qt,而對正确的不做任何事,為保證訓練穩定性,所作的修改盡可能小一些。

(2)在學生訓練期間,從教師網絡得到的參數qt不可變,是以優化對象可用交叉熵表示,而不是用KL。

(3)交叉熵計算不用y,因為A(qt)是完全正确的。

3.2.1 Why not LSR?

 Label Smooth Regularization (LSR)屬于LA最簡單的實作方式,但是有所限制。

LSR label:

Preparing lessons: Improve knowledge distillation with Better supervision

 類别數量k   樣本x    脈沖信号δ(·)

LSR丢棄教師預測的非true類别的機率,但這在KD中被證明是有幫助的。

是以,另一種簡單實作被提出:Probability Shift(PS). 機率轉移

思想:交換真實值标簽值(理論上最大值)與預測類别值(預測的最大值),以保證最大置信度落在真實值标簽。

Preparing lessons: Improve knowledge distillation with Better supervision

Fig2. PD on 誤判樣本的soft target ,The sample is from CIFAR-100 training data, whose ground truth label is leopard but ResNet-50 teacher’s prediction is rabbit.The value of ”leopard” is still large, which indicates that the teacher does not go ridiculous. 轉換操作就是交換兩個類的值,得到一個leopard的最大預測值,rabbit的第二大預測值。

與LSR相比,PS保留了涉及微小機率的類别間差異,LSR則丢棄了大部分。不正确的預測類别往往與真實類别有一些相似的特征。也就是說,不正确的預測類别可能比其他類别包含更多資訊。該方法還保留了軟目标的數字分布,這對穩定訓練過程是有幫助的。

動态溫度蒸餾:

[30,31,44]研究表明學生可以從監督的不确定性中受益,但教師的過度不确定的預測也可能會影響性能。

下圖:蒸餾softmax的可視化

圖中可以看到随着溫度的升高,各個類之間的機率差異變小,而真實值是leopard,另外兩種kangroo和rabbit在訓練中就是幹擾項,是以為更好區分(擴大類間相似度),應當選擇更小些的溫度值。(但之前有提到過較高的溫度值可以讓softmax輸出更加soft,但這裡用較高的溫度會讓非真實類稱為訓練幹擾項,是以采用動态溫度DTD的思想)

Preparing lessons: Improve knowledge distillation with Better supervision

 DTD描述:這裡用的KL散度,而不是cross entropy,因為ptx是變化的,不是一個常量。

Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision

t0和β代表 基礎溫度和偏差,wx代表樣本x的批量歸一化權重,描述混淆的程度。當樣本x有些混淆且教師預測值不确定時,wx會增加。如此一來tx<t0,soft targets更加有區分性。

======》混亂的樣本會有更大的weights,更低的溫度值。---樣本之間就更加有區分度。

======》DTD更加關注那些confusing examples,就可以視為是一種hard examples mining.

文中提出兩種方法計算權重wx,:

一種是FLSW計算sample-wise 權重; 另一種是依據學生預測的最大輸出計算wx,稱為Confidence Weighted by Student Max(CWSM)

Focal loss style weights:

原來的focal loss:

Preparing lessons: Improve knowledge distillation with Better supervision

p為一個樣本的分類分數。

本文方法中:學習難度可通過學生的logit v和教師的logit t之間的相似性來衡量。為簡便,将r設定為一個常量,得

(wx代表迷惑程度即難分類程度)

Preparing lessons: Improve knowledge distillation with Better supervision

v·t∈[-1,1]是兩個分布的内積。當學生預測與教師的預測相差甚遠時,wx就會變大。 

(回顧:A·B 内積計算是由第一個矩陣的每一行乘以第二個矩陣的每一列得到的)

Confidence Weighted by Student Max

根據學生歸一化的logits的最大值給樣本權重,在一定程度上可以反應樣本的學習情況。

學生模型通常對confusing 樣本有着不确定的預測,其logits的最大值也相應小一些,這裡計算wx公式描述為:

Preparing lessons: Improve knowledge distillation with Better supervision

其中學生的logit v 應該是normalized的,vmax被視為代表學生對樣本的置信度 。低置信度的樣本有更高的權重,這些樣本的梯度在蒸餾過程中也貢獻的更多。

Compound loss function and algorithm複式損失函數和算法

結合LA和DTD,整體損失:

Preparing lessons: Improve knowledge distillation with Better supervision

 與公式(3)類似,但不同在于(10)采用sample-wise溫度來soften logits.

監督張量A(qtx)會随着學習情況不斷變化。

此外這裡沒必要使用真實值交叉熵,因為A(qtx)總是正确的。

Preparing lessons: Improve knowledge distillation with Better supervision

 實驗:

資料集:CIFAR-10, CIFAR-100, Tiny ImageNet

方法比較:标準蒸餾 (KD),注意力機制 (AT),神經元選擇性遷移 (NST)

标準蒸餾KD:

公式(2)中α=0.7,友善起見,用KL散度實作兩個分布之間的交叉熵。

AT:KD中引入注意力機制

Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision

NST:

注:MMD(Maximum Mean Discrepancy)

Preparing lessons: Improve knowledge distillation with Better supervision

 SP:Similarity Preserving Distillation

 從特征相似性的角度提出了一種新的蒸餾損失,引導學生模仿樣本間的相似性,而不是教師對空間的邏輯.

Preparing lessons: Improve knowledge distillation with Better supervision

其中b為batch size,||·||F是矩陣的Frobenius範數,矩陣中的元素的平方和再開方。對于向量而言就是L2距離。Gt和Gs分别為教師和學生模型certain layer的相似性矩陣。

Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision
Preparing lessons: Improve knowledge distillation with Better supervision

繼續閱讀