Mean Teacher出自此文。本文所用代碼引用自此處。接下來我們以偏僞代碼的風格來通俗解釋Mean Teacher。
首先,Mean Teacher中有兩個網絡,一個稱為Teacher,一個稱為Student,其結構完全一緻,隻是網絡權重更新方法不同:
model = create_model() # Student Model
ema_model = create_model(ema=True) # Teacher Model (Equipped with EMA)
先暫時不管EMA是什麼意思。一般來講,在半監督中,每個輸入Batch包含一半已标注的圖像與一般未标注的圖像。首先,整個Batch會被送入Student Model中,得到一個預測結果。對于Batch中的已标注部分,利用結果與真值計算loss,進行梯度反傳,進而更新Student Model的參數,如下所示:
outputs = model(volume_batch) # 将圖像輸入Student中
supervised_loss = ce_loss(outputs[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) # 計算已标注部分的loss
而對于Batch中的未标注部分,其輸入Student Model也會得到一個結果(記為A),這個結果有什麼用呢?現在我們來看Teacher Model。具體來說,未标注的圖像會在加入随機噪聲後,會被送入Teacher Model中,得到一個預測結果(記為B):
那麼我們希望A與B的結果保持一緻,如下所示:
現在來回答兩個問題:
- Q1:EMA是什麼?Teacher模型不通過Loss反傳更新梯度,那麼其參數是怎麼更新的?
- A1:EMA即Exponential Moving Average,指數移動平均。通俗來講的話,Teacher模型的參數由Student模型過去一段時間的參數共同決定,可以通過拷貝Student模型的參數并計算以得到。這麼設計可以使Teacher模型反映Student在過去一段時間内的狀态。
- Q2:上面提到的Consistency為何會對半監督起到幫助?
- A2:有很多種了解。比方說,如果Teacher與Student能對相同的樣本得到一緻的結果,說明網絡目前的參數比較魯棒泛化——加噪前後的結果一緻,說明網絡不太可能overfit到一些特殊特征;在這種情況下網絡的預測結果一般是比較好的。