天天看點

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

文章目錄

    • 一、背景
    • 二、動機
    • 三、方法
      • 3.1 回顧 Spatial Distillation
      • 3.2 Channel-wise Distillation
    • 四、效果
    • 五、訓練和測試
    • 六、代碼解析

論文連結:https://arxiv.org/pdf/2011.13256.pdf

代碼連結:https://github.com/irfanICMLL/TorchDistiller

MMDetection:https://github.com/pppppM/mmdetection-distiller

MMSegmentation:https://github.com/pppppM/mmsegmentation-distiller

一、背景

密集預測是計算機視覺的一個重要基礎,如語義分割和目标檢測,這些任務需要學習特征的良好表達。目前較好的方法都需要大量的計算資源,難以在移動端部署。

分類任務上的蒸餾起到了明顯的效果[16, 2],但沒法直接用到語義分割,因為将逐個像素分類的任務嚴格對齊會導緻 student 模型過度學習 teacher 的輸出,無法獲得最優結果。

于是有一些方法 [25,24,18] 聚焦于加強不同 spatial 的聯系,如圖2a:

  • 首先,每個空間位置上的特征圖都被歸一化
  • 然後,通過聚合不同空間位置的子集來分析一些特定任務的關系,如 pair-wise 關系[25,35],和 inter-class 關系[18]。

二、動機

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
  • Spatial distillation: 空間方向的蒸餾,可以了解成對所有通道的相同位置的點做歸一化,然後讓學生網絡學習這個歸一化後的分布,可以了解成對類别的蒸餾。
  • Channel distillation: 通道方向的蒸餾,可以了解成對單個通道内做歸一化,然後讓學生網絡學習這個歸一化後的分布,可以了解成對位置的蒸餾。

雖然上面的這些方法比逐點對比好一些,但特征圖中的每個空間位置都對 konwledge transfering 貢獻相同,這樣可能從 teacher 帶來一些備援資訊。

還有一些方法使用了 channel 蒸餾,[50] 提出了将每個 channel 的 activation 聚合到一個聚合向量,這樣更有利于 image-level 的分類,但不适合于需要空間資訊的密集預測。

是以本文通過歸一化每個 channel 的特征圖來得到 soft probability map,如圖2b,然後最小化兩個網絡的 channel-wise probability map 的 asymmetry Kullback-Leibler(KL)散度,該KL 散度也就是 teacher 和 student 網絡的每個channel間的分布。一個例子如圖2c,每個 channel 的 activation map 會更關注于每個 channel 中的突出區域,也就是每個類别的突出區域,而這些區域恰恰是對密集預測很有用的。

  • COCO 上使用 RetinaNet(res50)提了3.4% mAP
  • Cityscape 上使用 PSPNet 提了5.81% mIoU

三、方法

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

*The activation values in this work include the final logits and the inner

feature maps

3.1 回顧 Spatial Distillation

通常的蒸餾方法是使用 point-wise 對齊的方式,形式如下:

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

3.2 Channel-wise Distillation

為了更好的利用每個 channel 中的知識,作者提出了對 teacher 和 student 網絡的對應 channel activation 進行 softly align。

  • 首先,将每個 channel 的 activation 轉換成機率分布,即可以使用機率分布度量方式來衡量其差異,如 KL 散度。如圖2c所示,每個 channel 的 activation 都趨向于對每個類别的突出特征進行編碼
  • 然後,使用訓練好的 teacher 模型來得到預測的 clear category-specific mask,如圖1右側所示,讓 student 網絡從 teacher 網絡中學習知識

Channel-wise distillation loss 如下:

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
  • y T y^T yT:teacher 的 activation map
  • y S y^S yS:student 的 activation map
  • ϕ \phi ϕ:将 activation value 轉換成機率分布的方式,如下所示,使用這種 softmax 歸一化,就可以消除大網絡和小網絡之間的數值大小之差。
    • c = 1 , 2 , . . . , C c = 1,2,...,C c=1,2,...,C :表示 channel
    • i i i : channel 中像素位置
    • T T T:溫度參數,也是一個超參數,當 T T T 越大,輸出的機率分布越 soft,即每個channel關注的空間區域就越大,-
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
  • 如何解決 teacher 和 student 的 channel 個數不一緻: 使用 1x1 卷積對 student 網絡個數進行上采樣
  • Φ \Phi Φ:用來衡量 teacher 和 student 的每個 channel 的機率分布的差異,本文使用 KL 散度
    • KL 散度是一個不對稱的衡量方式
    • 當 ϕ ( y c , i T ) \phi(y_{c,i}^T) ϕ(yc,iT​) 越大, ϕ ( y c , i S ) \phi(y_{c,i}^S) ϕ(yc,iS​) 也要越大,來最小化 KL 散度
    • 當 ϕ ( y c , i T ) \phi(y_{c,i}^T) ϕ(yc,iT​) 越小,則 KL 散度确不會讓 ϕ ( y c , i S ) \phi(y_{c,i}^S) ϕ(yc,iS​) 一直變小
    • 是以,student 網絡會更趨向于在前景突出特征的位置學習 teacher 網絡的分布,teacher 網絡分布的背景區域對學習産生的影響很小
      【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

四、效果

T = 4 T=4 T=4

logits map: α = 3 \alpha=3 α=3

feature map: α = 50 \alpha=50 α=50

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

消融實驗:

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

五、訓練和測試

以 mmsegmentation 的訓練代碼為例

1、安裝 mmsegmentation

2、軟連接配接資料:

cd mmsegmentation_distiller
mkdir data
ln -s cityscapes .
           

3、下載下傳訓練好的大模型 pspnet_r101,并放到 pretrained_model下,下載下傳模型路徑

4、訓練和測試

# 單 GPU 訓練
python tools/train.py configs/distiller/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py
# 訓練教師網絡
python tools/train.py configs/ocrnet/ocrnet_hr48_512x1024_80k_cityscapes.py

# 多 GPU 訓練
bash tools/dist_train.sh configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py 8

#單 GPU 測試
python tools/test.py configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py $CHECKPOINT --eval mIoU

#多 GPU 測試
bash tools/dist_test.sh configs/distillers/cwd/cwd_psp_r101-d8_distill_psp_r18_d8_512_1024_80k_cityscapes.py $CHECKPOINT 8 --eval mIoU

           

5、了解 config

config/distiller/cwd/cwd_psp_r101-d8_distill_psp_d8_512_1024_80k_cityscapes.py
           
_base_ = [
     '../../_base_/datasets/cityscapes.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_80k.py'
]


find_unused_parameters=True
weight=5.0
tau=1.0
distiller = dict(
    type='SegmentationDistiller',
    teacher_pretrained = 'pretrained_model/pspnet_r101b-d8_512x1024_80k_cityscapes_20201226_170012-3a4d38ab.pth',
    distill_cfg = [ dict(student_module = 'decode_head.conv_seg',
                         teacher_module = 'decode_head.conv_seg',
                         output_hook = True,
                         methods=[dict(type='ChannelWiseDivergence',
                                       name='loss_cwd',
                                       student_channels = 19,
                                       teacher_channels = 19,
                                       tau = tau,
                                       weight =weight,
                                       )
                                ]
                        ),
                    
                   ]
    )

student_cfg = 'configs/pspnet/pspnet_r18-d8_512x1024_80k_cityscapes.py'
teacher_cfg = 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py'
           
  • 教師網絡

    decode_head.conv_seg

$ p teacher_modules['decode_head.conv_seg']
>>> 
Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
           
  • 學生網絡

    decode_head.conv_seg

$ p student_modules['decode_head.conv_seg']
>>> 
Conv2d(128, 19, kernel_size=(1, 1), stride=(1, 1))
           

6、psp 教師網絡解碼頭結構:

(decode_head): PSPHead(
      input_transform=None, ignore_index=255, align_corners=False
      (loss_decode): CrossEntropyLoss()
      (conv_seg): Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
      (dropout): Dropout2d(p=0.1, inplace=False)
      (psp_modules): PPM(
        (0): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): ConvModule(
            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (1): Sequential(
          (0): AdaptiveAvgPool2d(output_size=2)
          (1): ConvModule(
            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (2): Sequential(
          (0): AdaptiveAvgPool2d(output_size=3)
          (1): ConvModule(
            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (3): Sequential(
          (0): AdaptiveAvgPool2d(output_size=6)
          (1): ConvModule(
            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
      )
      (bottleneck): ConvModule(
        (conv): Conv2d(4096, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (auxiliary_head): FCNHead(
      input_transform=None, ignore_index=255, align_corners=False
      (loss_decode): CrossEntropyLoss()
      (conv_seg): Conv2d(256, 19, kernel_size=(1, 1), stride=(1, 1))
      (dropout): Dropout2d(p=0.1, inplace=False)
      (convs): Sequential(
        (0): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
    )
  )
           

7、psp 學生網絡解碼頭結構:

(decode_head): PSPHead(
      input_transform=None, ignore_index=255, align_corners=False
      (loss_decode): CrossEntropyLoss()
      (conv_seg): Conv2d(128, 19, kernel_size=(1, 1), stride=(1, 1))
      (dropout): Dropout2d(p=0.1, inplace=False)
      (psp_modules): PPM(
        (0): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): ConvModule(
            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (1): Sequential(
          (0): AdaptiveAvgPool2d(output_size=2)
          (1): ConvModule(
            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (2): Sequential(
          (0): AdaptiveAvgPool2d(output_size=3)
          (1): ConvModule(
            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
        (3): Sequential(
          (0): AdaptiveAvgPool2d(output_size=6)
          (1): ConvModule(
            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activate): ReLU(inplace=True)
          )
        )
      )
      (bottleneck): ConvModule(
        (conv): Conv2d(1024, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (auxiliary_head): FCNHead(
      input_transform=None, ignore_index=255, align_corners=False
      (loss_decode): CrossEntropyLoss()
      (conv_seg): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
      (dropout): Dropout2d(p=0.1, inplace=False)
      (convs): Sequential(
        (0): ConvModule(
          (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
    )
  )
  (distill_losses): ModuleDict(
    (loss_cwd): ChannelWiseDivergence()
  )
)
           

這裡的

decode_head.seg_conv

其實是最後一層的輸出,即 PSP 頭輸出的最終結果,每個通道表示一個類别目标的預測。

8、如何修改為其他網絡結構的蒸餾

這裡以 OCR 網絡為例,psp 中是使用網絡的

decode_head.seg_conv

作為輸入的,我們首先需要看一下 OCR 網絡的

decode_head

結構,然後也取最後一層的輸出,即最後一層頭的

seg_conv

作為蒸餾的輸入,這裡以 hr48 作為教師網絡,hr18s作為學生網絡:

教師網絡

decode_head

ModuleList(
  (0): FCNHead(
    input_transform=resize_concat, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss()
    (conv_seg): Conv2d(270, 19, kernel_size=(1, 1), stride=(1, 1))
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(270, 270, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(270, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  (1): OCRHead(
    input_transform=resize_concat, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss()
    (conv_seg): Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
    (object_context_block): ObjectAttentionBlock(
      (key_project): Sequential(
        (0): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (1): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
      (query_project): Sequential(
        (0): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (1): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
      )
      (value_project): ConvModule(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (out_project): ConvModule(
        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
      (bottleneck): ConvModule(
        (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
    (spatial_gather_module): SpatialGatherModule()
    (bottleneck): ConvModule(
      (conv): Conv2d(270, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
  )
)
           

基于此,OCR 網絡的蒸餾輸入:

  • 教師網絡
$ p teacher_modules['decode_head.1.conv_seg']
>>>
Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
           
  • 學生網絡
$ p student_modules['decode_head.1.conv_seg']
>>>
Conv2d(512, 19, kernel_size=(1, 1), stride=(1, 1))
           

是以隻需要修改config即可,大模型是在mmsegmentation 官方代碼中下載下傳的,最終config如下:

_base_ = [
     '../../_base_/datasets/cityscapes.py',
    '../../_base_/default_runtime.py', '../../_base_/schedules/schedule_80k.py'
]


find_unused_parameters=True
weight=5.0
tau=1.0
distiller = dict(
    type='SegmentationDistiller',
    teacher_pretrained = 'pretrained_model/ocrnet_hr48_512x1024_160k_cityscapes.pth',
    distill_cfg = [ dict(student_module = 'decode_head.1.conv_seg',
                         teacher_module = 'decode_head.1.conv_seg',
                         output_hook = True,
                         methods=[dict(type='ChannelWiseDivergence',
                                       name='loss_cwd',
                                       student_channels = 19,
                                       teacher_channels = 19,
                                       tau = tau,
                                       weight =weight,
                                       )
                                ]
                        ),
                    
                   ]
    )

student_cfg = 'configs/ocrnet/ocrnet_hr18s_512x1024_80k_cityscapes.py'
teacher_cfg = 'configs/ocrnet/ocrnet_hr48_512x1024_160k_cityscapes.py'
           

代碼訓練:

python tools/train.py configs/distiller/cwd/cwd_ocr_hr48-d8_distill_ocr_hr18s-d8_512_1024_80k_cityscapes.py
           

訓練結果記錄:

cityscapes/ val /512x1024/ 80k iter/

教師網絡結構 mIoU 學生網絡結構 mIoU(蒸餾) mIoU(未蒸餾)
psp_r101 (272.4M) 79.74 psp_r18 (51.2M) 74.86
ocr_hr48 (282.2M) 81.35 ocr_hr18s (25.8M) 79.68 77.29

六、代碼解析

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

如果沒有 distiller config 的話,則會按照正常訓練方式訓練,distiller config 如下:

distiller_cfg = cfg.get('distiller', None)
$ p disstiller_cfg
>>>
{'type': 'SegmentationDistiller', 'teacher_pretrained': 'pretrained_model/ocrnet_hr48_512x1024_160k_cityscapes.pth', 
'distill_cfg': [{'student_module': 'decode_head.1.conv_seg', 'teacher_module': 'decode_head.1.conv_seg', 
'output_hook': True, 'methods': [{'type': 'ChannelWiseDivergence', 'name': 'loss_cwd', 
'student_channels': 19, 'teacher_channels': 19, 'tau': 1.0, 'weight': 5.0}]}]}
           

使用

Config.fromfile()

即可把

config

檔案中的内容拿出來:

teacher_cfg = Config.fromfile(cfg.teacher_cfg)
student_cfg = Config.fromfile(cfg.student_cfg)
           

訓練的時候使用的是 student 模型的

train_cfg

test_cfg

model = build_distiller(cfg.distiller,teacher_cfg,student_cfg,
         train_cfg=student_cfg.get('train_cfg'), 
         test_cfg=student_cfg.get('test_cfg'))
           

蒸餾的訓練方式和普通的訓練方式不同之一:optimezier 優化的參數不同,蒸餾的話,隻有student 的參數和蒸餾 loss 的參數參與訓練。

# build runner
distiller_cfg = cfg.get('distiller',None)
if distiller_cfg is None:
    optimizer = build_optimizer(model, cfg.optimizer)
else:
	# base_parameters() 在 segmentation_distiller.py line 69
	# base_parameters() 包括 student 和 distill_loss
    optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer)
           

可以使用這樣的方式來檢視不需要參數訓練的參數:

# pytorch 中需要訓練的參數
model.named_parameters()
# 不需要參數訓練的參數
model.named_buffers()
           

pytorch 可以使用

register_buffer()

來使得該參數不參與訓練

# name 是名字, 參數是登記的不參與訓練的參數
register_buffer(name, 參數)
           
buffer_key = [k for k,v in self.named_buffers()]
>>>
['student_decode_head_1_conv_seg', 'teacher_decode_head_1_conv_seg', 'teacher.backbone.bn1.running_mean', 'teacher.backbone.bn1.running_var', 'teacher.backbone.bn1.num_batches_tracked', 'teacher.backbone.bn2.running_mean', 'teacher.backbone.bn2.running_var', 'teacher.backbone.bn2.num_batches_tracked', ...
           

蒸餾的訓練方法:分兩步,第一步計算不參與蒸餾的層的 loss,然後計算參與蒸餾的層的loss

mmseg/distillation/distillers/segmentation_distiller.py
           
def forward_train(self, img, img_metas, gt_semantic_seg):
    with torch.no_grad():
        self.teacher.eval()
        teacher_loss = self.teacher.forward_train(img, img_metas, gt_semantic_seg) # mmseg/models/segmentors/encoder_decoder.py(136)forward_train()
       
    student_loss = self.student.forward_train(img, img_metas, gt_semantic_seg)
    # 整體loss
    # {'decode_0.loss_seg': tensor(1.1701, device='cuda:0', grad_fn=<MulBackward0>), 'decode_0.acc_seg': tensor([3.7306], device='cuda:0'), \
    # 'decode_1.loss_seg': tensor(2.9231, device='cuda:0', grad_fn=<MulBackward0>), 'decode_1.acc_seg': tensor([6.0701], device='cuda:0')}
    
    buffer_dict = dict(self.named_buffers())  # named_buffers() 檢視網絡中不需要更新的參數, parameters()檢視網絡中需要更新的參數
    for item_loc in self.distill_cfg:
        student_module = 'student_' + item_loc.student_module.replace('.','_') # 'student_decode_head_1_conv_seg'
        teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_') # 'teacher_decode_head_1_conv_seg'
        # 下面這兩步是關鍵,提取的是教師網絡和學生網絡的輸入 decode_head 之前的輸出,如下圖所示
        student_feat = buffer_dict[student_module] # [b, 19, 128 256]
        teacher_feat = buffer_dict[teacher_module] # [b, 19, 128 256]
        for item_loss in item_loc.methods: # item_loc.methods: [{'type': 'ChannelWiseDivergence', 'name': 'loss_cwd', 'student_channels': 19, 'teacher_channels': 19, 'tau': 1.0, 'weight': 5.0}]
            loss_name = item_loss.name     # 'loss_cwd'
            student_loss[ loss_name] = self.distill_losses[loss_name](student_feat,teacher_feat)
            # 增加了蒸餾 loss 後的loss: 
            # {'decode_0.loss_seg': tensor(1.1701, device='cuda:0', grad_fn=<MulBackward0>), 'decode_0.acc_seg': tensor([3.7306], device='cuda:0'),
            # 'decode_1.loss_seg': tensor(2.9231, device='cuda:0', grad_fn=<MulBackward0>), 'decode_1.acc_seg': tensor([6.0701], device='cuda:0'), 
            # 'loss_cwd': tensor(51.9439, device='cuda:0', grad_fn=<DivBackward0>)}
    
    return student_loss
           

下面這兩組特征的特征圖如下圖所示,學生網絡是第一次疊代的特征圖,還沒有學到任何特征

student_feat = buffer_dict[student_module] # [b, 19, 128 256]
teacher_feat = buffer_dict[teacher_module] # [b, 19, 128 256]
           

teacher_feat:

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

student_feat:

【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

看一下這兩個特征是怎麼來的,這裡是使用 hook 來擷取這兩層的輸出特征來得到的這兩組特征,每次執行個體化SegmentationDistiller 這個類的時候,其 init 裡邊都會走一遍特征注冊的過程,保證每次疊代後的特征放入 hook 裡邊:

hook 分為兩種:

  • register_forward_hook(hook)

  • register_backward_hook(hook)

hook 的作用是擷取某些變量的中間結果,因為pytorch會自動舍棄圖計算的中間結果,是以想要擷取這些數值就需要使用 hook 函數,hook 函數在使用後需要及時删除,避免每次都運作其增加負載。

# 這裡寫了一個注冊的 hook
def regitster_hooks(student_module,teacher_module):
    def hook_teacher_forward(module, input, output):
    		# 這裡的 input 和 output 是這層的輸入和輸出
        	self.register_buffer(teacher_module,output) # 通過register_buffer()登記過的張量:會自動成為模型中的參數,随着模型移動(gpu/cpu)而移動,但是不會随着梯度進行更新。
    def hook_student_forward(module, input, output):
            self.register_buffer( student_module,output )
    return hook_teacher_forward,hook_student_forward

for item_loc in distill_cfg:
    
    student_module = 'student_' + item_loc.student_module.replace('.','_') # 'student_decode_head_1_conv_seg'
    teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_') # 'teacher_decode_head_1_conv_seg'
    # 這裡進行

    hook_teacher_forward,hook_student_forward = regitster_hooks(student_module ,teacher_module )
    teacher_modules[item_loc.teacher_module].register_forward_hook(hook_teacher_forward)
    student_modules[item_loc.student_module].register_forward_hook(hook_student_forward)
           

register_forward_hook(hook) 作用就是(假設想要conv2層),那麼就是根據 model(該層),該層input,該層output,可以将 output擷取。

register_forward_hook(hook) 最大的作用也就是當訓練好某個model,想要展示某一層對最終目标的影響效果。

求loss的方法:

import torch.nn as nn
import torch.nn.functional as F
import torch

from .utils import weight_reduce_loss
from ..builder import DISTILL_LOSSES


@DISTILL_LOSSES.register_module()
class ChannelWiseDivergence(nn.Module):

    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation
     <https://arxiv.org/abs/2011.13256>`_.
   
    Args:
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map.
        name(str): 
        tau (float, optional): Temperature coefficient. Defaults to 1.0.
        weight (float, optional): Weight of loss.Defaults to 1.0.
        
    """
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name,
                 tau=1.0,
                 weight=1.0,
                 ):
        super(ChannelWiseDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = weight
    
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None


    def forward(self,
                preds_S,
                preds_T):
        """Forward function."""
        assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'
        N,C,W,H = preds_S.shape  # [2, 19, 128, 256]

        if self.align is not None:
            preds_S = self.align(preds_S)

        softmax_pred_T = F.softmax(preds_T.view(-1,W*H)/self.tau, dim=1)
        softmax_pred_S = F.softmax(preds_S.view(-1,W*H)/self.tau, dim=1)
        
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        loss = torch.sum( - softmax_pred_T * logsoftmax(preds_S.view(-1,W*H)/self.tau)) * (self.tau ** 2)
        return self.loss_weight * loss / (C * N)
           
【知識蒸餾】ICCV21_Channel-wise Knowledge Distillation for Dense Prediction

這裡 KL 散度公式如上,展開後是這樣的:

D K L = ∑ p l o g p − p l o g q = ∑ T l o g T − T l o g S D_{KL} = \sum p logp-plogq=\sum TlogT-TlogS DKL​=∑plogp−plogq=∑TlogT−TlogS

前一項實際上是教師網絡的輸出,是固定不變的,是以最終的形式變成了 ∑ − T l o g S \sum-TlogS ∑−TlogS,也就是上面的代碼中的形式。

這裡以 OCR 為例解釋一下 loss 的組成:FCN loss + OCR loss + distillation loss

1、原始loss的計算:

  • OCR 是 cascade_docode_head,因為其解碼頭由 FCN 和 OCR 組成
  • FCN 的輸入是backbone的輸出,FCN 拿到一組 backbone 的輸出(有四組不同大小的特征圖構成,通道數共為270),然後輸出成 [N, 19, 128, 256] 的特征圖進行loss計算,這裡就是總loss中的

    'decode_0.loss_seg'

是以,在

segmentation_distiller.py

中計算原本的 loss 的時候,loss 會找到

mmseg/models/segmentors/cascade_encoder_decoder.py

中來計算前向傳播的loss:

def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
    """Run forward function and calculate loss for decode head in
    training."""
    losses = dict()
    # 先計算 decode_head[0] 的 loss,即 FPN 的 loss
    # 第一個 decode_head 走的是 cascade_head.py 的 forward_train 的過程
    loss_decode = self.decode_head[0].forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
	# loss_decode: {'loss_seg': tensor(1.1506, device='cuda:0', grad_fn=<MulBackward0>), 'acc_seg': tensor([1.5568], device='cuda:0')}
	
    losses.update(add_prefix(loss_decode, 'decode_0'))
    # loss: {'decode_0.loss_seg': tensor(1.1506, device='cuda:0', grad_fn=<MulBackward0>), 'decode_0.acc_seg': tensor([1.5568], device='cuda:0')}

    for i in range(1, self.num_stages): # config/models/ocrnet_hr18.py 中寫了 num_stage=2
        # forward test again, maybe unnecessary for most methods.
        
        # prev_outputs 是将 backbone 的輸出又走了一遍 FPN 得到的輸出,即 decode_head[0] 的輸出 [N, 19, 128, 256]
        prev_outputs = self.decode_head[i - 1].forward_test(x, img_metas, self.test_cfg)
        
        # 然後将 FPN 的輸出作為 loss 的輸入
        # 第二個及之後的 decode_heads 都會走 cascade_decode_head 的 forward_train,走到 ocr_head.py 中去
        # mmseg/models/decode_heads/cascade_decode_head.py # line 18
        # 這裡的 x 是 backbone的輸出(270維),prev_outputs 是 FPN 的輸出
        # OCRnet 會利用backbone 的輸出和 FPN 的輸出,做一個自己的注意力操作,得到 [N, 19, 128, 256] 的輸出,然後和真值做 loss
        loss_decode = self.decode_head[i].forward_train(x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
        losses.update(add_prefix(loss_decode, f'decode_{i}'))
        # {'decode_0.loss_seg': tensor(1.1506, device='cuda:0', grad_fn=<MulBackward0>), 'decode_0.acc_seg': tensor([1.5568], device='cuda:0'), 'decode_1.loss_seg': tensor(2.8385, device='cuda:0', grad_fn=<MulBackward0>), 'decode_1.acc_seg': tensor([1.2970], device='cuda:0')}
    return losses
           
# mmseg/models/decode_heads/decode_head.py # line 170
# decode_head[0] 的計算 loss
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
	# inputs.shape [2, 19, 128, 256]
	# 
    seg_logits = self.forward(inputs)
    losses = self.losses(seg_logits, gt_semantic_seg)
    return losses
           
# mmseg/models/decode_heads/cascade_decode_head.py # line 18
# decode_head[1] 及之後 head 的計算 loss
def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                  train_cfg):
    seg_logits = self.forward(inputs, prev_output)
    losses = self.losses(seg_logits, gt_semantic_seg)
    return losses
           
# mmseg/models/decode_heads/decode_head.py
@force_fp32(apply_to=('seg_logit', ))
def losses(self, seg_logit, seg_label):
    """Compute segmentation loss."""
    loss = dict()
    # 先把預測的 128x256 的結果上采樣到 512x1024的,和真值大小一樣
    seg_logit = resize(
        input=seg_logit,
        size=seg_label.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    if self.sampler is not None:
        seg_weight = self.sampler.sample(seg_logit, seg_label)
    else:
        seg_weight = None
    seg_label = seg_label.squeeze(1)
    # 進入 cross_entropy_loss # mmseg/models/losses/cross_entropy_loss.py
    loss['loss_seg'] = self.loss_decode(
        seg_logit,
        seg_label,
        weight=seg_weight,
        ignore_index=self.ignore_index)
    loss['acc_seg'] = accuracy(seg_logit, seg_label)
    return loss
    # 得到 'acc_seg' 和 'loss_seg'
           

2、蒸餾 loss 的計算:計算

def forward(self, preds_S, preds_T):
    """Forward function."""
    assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'
    N,C,W,H = preds_S.shape

    if self.align is not None:
        preds_S = self.align(preds_S)
    # 這裡的歸一化方式是唯一能展現 channel 的地方
    # 對每個channel的所有元素進行歸一化,然後讓學生網絡學習歸一化後的通道特征
    softmax_pred_T = F.softmax(preds_T.view(-1,W*H)/self.tau, dim=1) #[NxC, 32768]
    logsoftmax = torch.nn.LogSoftmax(dim=1)
    loss = torch.sum( - softmax_pred_T * logsoftmax(preds_S.view(-1,W*H)/self.tau)) * (self.tau ** 2)
    return self.loss_weight * loss / (C * N)
           

最終的 loss 如下:

然後在

mmseg/models/segmentors/base.py

中,求 loss 的和:

loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
           
{
'loss':
	 tensor(55.8550, device='cuda:0', grad_fn=<AddBackward0>), 
'log_vars': 
	OrderedDict([('decode_0.loss_seg', 1.0829237699508667), 
				('decode_0.acc_seg', 10.901641845703125), 
				('decode_1.loss_seg', 2.7209525108337402), 
				('decode_1.acc_seg', 2.446269989013672), 
				('loss_cwd', 52.051116943359375), 
				('loss', 55.8549919128418)]), 
				'num_samples': 2
}
           

繼續閱讀