天天看點

【從零開始學CenterNet】7. CenterNet測試推理過程

【GiantPandaCV導語】這是CenterNet系列的最後一篇。本文主要講CenterNet在推理過程中的資料加載和後處理部分代碼。最後提供了一個已經配置好的資料集供大家使用。

代碼注釋在:https://github.com/pprp/SimpleCVReproduction/tree/master/CenterNet

由于CenterNet是生成了一個heatmap進行的目标檢測,而不是傳統的基于anchor的方法,是以訓練時候的資料加載和測試時的資料加載結果是不同的。并且在測試的過程中使用到了Test Time Augmentation(TTA),使用到了多尺度測試,翻轉等。

在CenterNet中由于不需要非極大抑制,速度比較快。但是CenterNet如果在測試的過程中加入了多尺度測試,那就回調用soft nms将不同尺度的傳回的框進行抑制。

以上是eval過程的資料加載部分的代碼,主要有兩個需要關注的點:

如果是多尺度會根據test_scale的值傳回不同尺度的結果,每個尺度都有img,center等資訊。這部分代碼可以和test.py代碼的多尺度處理一塊了解。

尺度處理部分,有一個padding參數

這部分代碼作用就是通過按位或運算,找到最接近的2的倍數-1作為最終的尺度。

例如:輸入512,多尺度開啟:0.5,0.7,1.5,那最終的結果是

512 x 0.5 | 31 = 287

512 x 0.7 | 31 = 383

512 x 1.5 | 31 = 799

【從零開始學CenterNet】7. CenterNet測試推理過程

上圖是CenterNet的結構圖,使用的是PlotNeuralNet工具繪制。在推理階段,輸入圖檔通過骨幹網絡進行特征提取,然後對下采樣得到的特征圖進行預測,得到三個頭,分别是offset head、wh head、heatmap head。

推理過程核心工作就是從heatmap提取得到需要的bounding box,具體的提取方法是使用了一個3x3的最大化池化,檢查目前熱點的值是否比周圍8個臨近點的值都大。然後取100個這樣的點,再做篩選。

以上過程的核心函數是:

<code>ctdet_decode</code>這個函數功能就是将heatmap轉化成bbox:

第一步

将hmap歸一化,使用了sigmoid函數

第二步

進入<code>_nms</code>函數:

hmax代表特征圖經過3x3卷積以後的結果,keep為極大點的位置,傳回的結果是篩選後的極大值點,其餘不符合8-近鄰極大值點的都歸為0。

這時候通過heatmap得到了滿足8近鄰極大值點的所有值。

這裡的nms曾經在群裡讨論過,有群友認為僅通過3x3的并不合理,可以嘗試使用3x3,5x5,7x7這樣的maxpooling,相當于也進行了多尺度測試,據說能提高一點點mAP。

第三步

進入<code>_topk</code>函數,這裡K是一個超參數,CenterNet中設定K=100

torch.topk的一個demo如下:

topk_scores和topk_inds分别是前K個score和對應的id。

topk_scores 形狀【batch, class, K】K代表得分最高的前100個點, 其儲存的内容是每個類别前100個最大的score。

topk_inds 形狀 【batch, class, K】class代表80個類别channel,其儲存的是每個類别對應100個score的下角标。

topk_score 形狀 【batch, K】,通過gather feature 方法擷取,其儲存的是全部類别前100個最大的score。

topk_ind 形狀 【batch , K】,代表通過topk調用結果的下角标, 其儲存的是全部類别對應的100個score的下角标。

topk_inds、topk_ys、topk_xs三個變量都經過gather feature函數,其主要功能是從對應張量中根據下角标提取結果,具體函數如下:

以topk_inds為例(K=100,class=80)

feat (topk_inds) 形狀為:【batch, 80x100, 1】

ind (topk_ind) 形狀為:【batch,100】

<code>ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)</code>擴充一個位置,ind形狀變為:【batch, 100, 1】

<code>feat = feat.gather(1, ind)</code>按照dim=1擷取ind,為了友善了解和回憶,這裡舉一個例子:

相當于是feat根據ind的角标的值擷取到了對應feat位置上的結果。最終feat形狀為【batch,100,1】

第四步

經過topk函數,得到了四個傳回值,topk_score、topk_inds、topk_ys、topk_xs四個參數的形狀都是【batch, 100】,其中topk_inds是每張圖檔的前100個最大的值對應的index。

<code>regs = _tranpose_and_gather_feature(regs, inds)</code>

<code>w_h_ = _tranpose_and_gather_feature(w_h_, inds)</code>

transpose_and_gather_feat函數功能是将topk得到的index取值,得到對應前100的regs和wh的值。

到這一步為止,可以将top100的score、wh、regs等值提取,并且得到對應的bbox,最終ctdet_decode傳回了detections變量。

之前在CenterNet系列第一篇PyTorch版CenterNet訓練自己的資料集中講解了如何配置資料集,為了更友善學習和調試這部分代碼,筆者從github上找到了一個浣熊資料集,這個資料集僅有200張圖檔,友善大家快速訓練和debug。

【從零開始學CenterNet】7. CenterNet測試推理過程
連結:https://pan.baidu.com/s/1unK-QZKDDaGwCrHrOFCXEA 提取碼:pdcv

以上資料集已經制作好了,隻要按照第一篇文章中将DCN、NMS等編譯好,就可以直接使用。

https://blog.csdn.net/fsalicealex/article/details/91955759

https://zhuanlan.zhihu.com/p/66048276

https://zhuanlan.zhihu.com/p/85194783

代碼改變世界

繼續閱讀