天天看点

YOLO算法改进Backbone系列之:HAT-Net

作者:Nuist目标检测

本文旨在解决ViT中与多头自我关注(MHSA)相关的高计算/空间复杂性问题。为此,我们提出了分层多头自注意(H-MHSA),这是一种以分层方式计算自注意的新方法。具体来说,我们首先按照通常的方法将输入图像划分为多个斑块,每个斑块被视为一个标记。然后,提议的 H-MHSA 学习局部补丁内的标记关系,作为局部关系建模。然后,将小补丁合并成大补丁,H-MHSA 对合并后的少量标记进行全局依赖关系建模。最后,对局部和全局注意力特征进行汇总,以获得具有强大表征能力的特征。由于我们每一步只计算有限数量标记的注意力,因此计算负荷大大减少。

因此,H-MHSA 可以在不牺牲细粒度信息的情况下,有效地模拟标记之间的全局关系。有了 H-MHSA 模块,我们构建了一个基于分层注意力的变换器网络系列,即 HAT-Net。为了证明 HAT-Net 在场景理解方面的优越性,我们在图像分类、语义分割、物体检测和实例分割等基本视觉任务上进行了大量实验。因此,HAT-Net 为视觉转换器提供了一个新的视角。

现有问题及解决方案:Transformer在NLP领域中以成为了处理长距离依赖关系的事实标准,但其依赖于自注意力机制来建模序列数据的全局关系。随着视觉Transformer的代表性工作ViT的出现,基于像素patch构建Transformer模型的方式已经成为了视觉Transformer的主流范式,但是由于视觉数据中patch序列长度依然较长,其所依赖的Self-Attention操作在实际应用中仍然面临着较高的计算量和空间复杂度的问题。

最近的一些工作主要在尝试通过各种手段来压缩序列长度从而提升视觉Transformer的计算效率,主要如下:

  • Local Attention:Swin Transformer中使用固定大小的窗口,并搭配Shift Window并多层堆叠从而模拟全局建模,这种手段仍然次优,因为其仍然延续着CNN的堆叠模拟长距离依赖的思路
  • Pooling Attention:PVT对特征图下采样,从而缩小了序列长度。但是因为下采样了key和value,也因此丢失了局部细节,而且使用了固定尺寸的下采样比例,这使用的是具有与卷积核大小相同的步长的跨步卷积实现的;另外如果需要调整配置,就得需要重新训练
  • Channel Attention:CoaT计算了通道形式的注意力,这可能没有模拟全局特征依赖那么有效

针对MHSA提出了一种更加有效和灵活的变体—分层多头自注意力(Hierarchical Multi-Head Self-Attention,H-MHSA)。其通过将直接计算全局相似关系的MHSA拆解成了多个步骤,每步中具有不同粒度的短序列之间的相似性建模,从而既保留了细粒度信息,又保留了短序列计算的高效。而且H-MHSA涉及到缩短序列的操作都是无参数的,所以对于下游任务更加灵活,不需要因为调整而重新预训练。具体而言,H-MHSA中包含一下几个步骤:

  • 对于输入的query、key以及value对应的patch token,首先将它们进行分组,分成不重叠的数个grid
  • 在grid内的patch之间计算attention,从而捕获局部关系,产生更具判别性的局部表征。这里是基于残差形式
  • 将这些小patch合并,获得更大层级的patch token。这允许我们直接基于这些数量较少的粗粒度的token来模拟全局依赖关系。这里计算时,对k、v使用平均池化进行进行压缩处理。
  • 最后来自局部和全局层级的特征被集成,从而获得具有更加丰富粒度的特征
YOLO算法改进Backbone系列之:HAT-Net

下表总结了HAT-Net模型的不同配置列表

YOLO算法改进Backbone系列之:HAT-Net

在YOLOv5项目中添加模型作为Backbone使用的教程:

(1)将YOLOv5项目的models/yolo.py修改parse_model函数以及BaseModel的_forward_once函数

YOLO算法改进Backbone系列之:HAT-Net
YOLO算法改进Backbone系列之:HAT-Net

(2)在models/backbone(新建)文件下新建HAT_Net.py,添加如下的代码:

YOLO算法改进Backbone系列之:HAT-Net

(3)在models/yolo.py导入模型并在parse_model函数中修改如下(先导入文件):

YOLO算法改进Backbone系列之:HAT-Net

(4)在model下面新建配置文件:yolov5_hatnet.yaml

YOLO算法改进Backbone系列之:HAT-Net

(5)运行验证:在models/yolo.py文件指定--cfg参数为新建的yolov5_hatnet.yaml

YOLO算法改进Backbone系列之:HAT-Net

继续阅读