【GiantPandaCV导语】这是CenterNet系列的最后一篇。本文主要讲CenterNet在推理过程中的数据加载和后处理部分代码。最后提供了一个已经配置好的数据集供大家使用。
代码注释在:https://github.com/pprp/SimpleCVReproduction/tree/master/CenterNet
由于CenterNet是生成了一个heatmap进行的目标检测,而不是传统的基于anchor的方法,所以训练时候的数据加载和测试时的数据加载结果是不同的。并且在测试的过程中使用到了Test Time Augmentation(TTA),使用到了多尺度测试,翻转等。
在CenterNet中由于不需要非极大抑制,速度比较快。但是CenterNet如果在测试的过程中加入了多尺度测试,那就回调用soft nms将不同尺度的返回的框进行抑制。
以上是eval过程的数据加载部分的代码,主要有两个需要关注的点:
如果是多尺度会根据test_scale的值返回不同尺度的结果,每个尺度都有img,center等信息。这部分代码可以和test.py代码的多尺度处理一块理解。
尺度处理部分,有一个padding参数
这部分代码作用就是通过按位或运算,找到最接近的2的倍数-1作为最终的尺度。
例如:输入512,多尺度开启:0.5,0.7,1.5,那最终的结果是
512 x 0.5 | 31 = 287
512 x 0.7 | 31 = 383
512 x 1.5 | 31 = 799

上图是CenterNet的结构图,使用的是PlotNeuralNet工具绘制。在推理阶段,输入图片通过骨干网络进行特征提取,然后对下采样得到的特征图进行预测,得到三个头,分别是offset head、wh head、heatmap head。
推理过程核心工作就是从heatmap提取得到需要的bounding box,具体的提取方法是使用了一个3x3的最大化池化,检查当前热点的值是否比周围8个临近点的值都大。然后取100个这样的点,再做筛选。
以上过程的核心函数是:
<code>ctdet_decode</code>这个函数功能就是将heatmap转化成bbox:
第一步
将hmap归一化,使用了sigmoid函数
第二步
进入<code>_nms</code>函数:
hmax代表特征图经过3x3卷积以后的结果,keep为极大点的位置,返回的结果是筛选后的极大值点,其余不符合8-近邻极大值点的都归为0。
这时候通过heatmap得到了满足8近邻极大值点的所有值。
这里的nms曾经在群里讨论过,有群友认为仅通过3x3的并不合理,可以尝试使用3x3,5x5,7x7这样的maxpooling,相当于也进行了多尺度测试,据说能提高一点点mAP。
第三步
进入<code>_topk</code>函数,这里K是一个超参数,CenterNet中设置K=100
torch.topk的一个demo如下:
topk_scores和topk_inds分别是前K个score和对应的id。
topk_scores 形状【batch, class, K】K代表得分最高的前100个点, 其保存的内容是每个类别前100个最大的score。
topk_inds 形状 【batch, class, K】class代表80个类别channel,其保存的是每个类别对应100个score的下角标。
topk_score 形状 【batch, K】,通过gather feature 方法获取,其保存的是全部类别前100个最大的score。
topk_ind 形状 【batch , K】,代表通过topk调用结果的下角标, 其保存的是全部类别对应的100个score的下角标。
topk_inds、topk_ys、topk_xs三个变量都经过gather feature函数,其主要功能是从对应张量中根据下角标提取结果,具体函数如下:
以topk_inds为例(K=100,class=80)
feat (topk_inds) 形状为:【batch, 80x100, 1】
ind (topk_ind) 形状为:【batch,100】
<code>ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)</code>扩展一个位置,ind形状变为:【batch, 100, 1】
<code>feat = feat.gather(1, ind)</code>按照dim=1获取ind,为了方便理解和回忆,这里举一个例子:
相当于是feat根据ind的角标的值获取到了对应feat位置上的结果。最终feat形状为【batch,100,1】
第四步
经过topk函数,得到了四个返回值,topk_score、topk_inds、topk_ys、topk_xs四个参数的形状都是【batch, 100】,其中topk_inds是每张图片的前100个最大的值对应的index。
<code>regs = _tranpose_and_gather_feature(regs, inds)</code>
<code>w_h_ = _tranpose_and_gather_feature(w_h_, inds)</code>
transpose_and_gather_feat函数功能是将topk得到的index取值,得到对应前100的regs和wh的值。
到这一步为止,可以将top100的score、wh、regs等值提取,并且得到对应的bbox,最终ctdet_decode返回了detections变量。
之前在CenterNet系列第一篇PyTorch版CenterNet训练自己的数据集中讲解了如何配置数据集,为了更方便学习和调试这部分代码,笔者从github上找到了一个浣熊数据集,这个数据集仅有200张图片,方便大家快速训练和debug。
链接:https://pan.baidu.com/s/1unK-QZKDDaGwCrHrOFCXEA 提取码:pdcv
以上数据集已经制作好了,只要按照第一篇文章中将DCN、NMS等编译好,就可以直接使用。
https://blog.csdn.net/fsalicealex/article/details/91955759
https://zhuanlan.zhihu.com/p/66048276
https://zhuanlan.zhihu.com/p/85194783
代码改变世界