論文: Cascade R-CNN Delving into High Quality Object Detection (CVPR 2018)
代碼:cascade-rcnn_Pytorch
文章目錄
-
- 為什麼級聯
-
-
- 總結
-
- 代碼梳理
- 實驗
- 參考文獻
為什麼級聯
雙階段網絡的典型代表就是Faster RCNN了,先通過RPN網絡産生Proposals,然後挑選出正負樣本,并配置設定标簽進行訓練。訓練時選擇哪些proposals作為正負樣本,一般是根據IoU門檻值來界定(一般取0.5),和實際目标框高于該門檻值就作為正樣本。
- 圖(a):門檻值設定得低,則正樣本中含較多背景,使得誤檢較多
- 圖(b):門檻值設定得過高,雖然能減少誤檢,正樣本的數量變少,容易過拟合,會漏檢
- 圖(c):當輸入的IoU分布與門檻值較為接近的時候,其輸出IoU也相對較高,說明将門檻值調節到輸入IoU附近時,得到的輸出一般有更好的表現
- 圖(d):随着門檻值增加,檢測器的性能會大緻在各自對應IoU區間有一個更好的表現,總體上門檻值為0.6時表現更好
從上圖(a)還有可以看出很重要的一點就是:大部分時候,曲線都在灰色曲線上方,說明輸出IoU一般比輸入IoU要高,将輸出繼續作為輸入,相當于調高了Proposals的IoU(下一輪的輸入IoU更高),同時适當調高IoU門檻值,以得到更高IoU的輸出,其結構示意圖如下圖(d)所示,這便是CascadeRCNN的級聯結構
從下圖可以看出,每一個階段的輸出的IoU分布明顯不同,越到後面輸出的品質越高,相當于一個進化的過程,越好的Proposals訓練效果越好
總結
- 輸入IoU的分布在門檻值附近時,訓練效果相對較好(可以根據輸入IoU分布調整IoU門檻值)
- 輸出IoU一般比輸入IoU高(可以級聯)
代碼梳理
相比FasterRCNN,其實就是将後面的FastRCNN部分重複
- stage1的RoIs1由RPN網絡産生,stage1輸出的預測框繼續在前面的feature map上提取相應的RoIs2,作為stage2的輸入,stage3以此類推
##################stage1##################
self.RCNN_top = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred = nn.Linear(1024, 4 * self.n_classes)
##################stage2##################
self.RCNN_top_2nd = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score_2nd = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred_2nd = nn.Linear(1024, 4 * self.n_classes)
##################stage3##################
self.RCNN_top_3rd = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score_3rd = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred_3rd = nn.Linear(1024, 4 * self.n_classes)
實驗
參考文獻
【1】Cascade RCNN算法筆記
【2】Cascade R-CNN 詳細解讀
【3】cascade-rcnn_Pytorch