天天看点

Few-Shot Object Detection via Sample Processing

Few-Shot Object Detection via Sample Processing

来源 IEEE Access

Introduction

  • 正样本的缺失是限制少样本目标检测性能的主要原因,样本太少导致的结果是过拟合和模型泛化能力差
  • 本文从少样本目标检测的样本角度出发,对于一个样本 a 和两个类别 M、N ,如果 a 的标签为类别 M 那么对于 M 来说它是正样本,对于 N 来说他就是负样本。另外根据模型判断的难易程度,它又可以定义为难样本和易样本。结合这两种维度的定义,样本类型可以分为四个类别:easy-positive sample、hard-positive sample、easy-negative sample以及hard-negative sample.
  • 传统方法在少样本目标检测上遭遇的困境有以下三点:
    • 少量样本不足以描述类别的特征
    • 少量样本导致空间尺度分布缺失
    • 正样本的缺失导致负样本的增长
  • 针对以上问题,提出通过 sample processing 的样本目标检测的方法,基于yolov3-spp,引入自注意力模块(SAM)和正样本增强模块(PSA),同时在微调阶段修改网络的损失函数,加大对难样本的惩罚。

Annotation Space

Few-Shot Object Detection via Sample Processing

图中以三个类别为例描述标签空间的相关概念,三个圆中心代表的是类别的公共特征,根据样本离类别中心的距离将样本分为难样本和易样本。如上图所示,g对于A来说是一个 easy-positive sample,而h对于A来说是一个hard-positive sample,同时h还落入了B的区域,这对于A是有价值的,但同时对于B来说是一个威胁。

Proposed Method

少样本目标检测一个基本模型流程

Few-Shot Object Detection via Sample Processing

基于以上流程和sample processing的思路,本文的模型整体设计如下

Few-Shot Object Detection via Sample Processing
  • 模型主要由两个分支组成 basic-trunk 和 reinforcement-branch,基于孪生神经网络的reinforcement-branch 共享 basic-trunk 的权重
  • 训练过程 base training 和 fine-tuning.

    base training 的时候 reinforcement-branch 是冻结的

    fine-tuning 的时候两个分支同时执行

YOLOv3-SPP

相对于普通的 YOLOv3 在第五和第六个卷积层中加入 SPP 层,由 1 × 1 , 5 × 5 , 9 × 9 , 13 × 13 1 \times 1, 5\times5,9\times9,13\times13 1×1,5×5,9×9,13×13​​的最大池化操作组成​​​​,用以丰富特征图的尺度分布。

Few-Shot Object Detection via Sample Processing

Self-Attention Module

自注意力模块的实现

将输入的特征通过不同的池化和卷积操作映射到不同的特征空间 f(x), g(x) 和 h(x). 最大池化操作保留目标的 texture feature,平均池化操作可以保留背景信息,上述过程可以用下面的公式表示

θ ( x ) = C θ [ p θ ( x ) ] ,   θ ∈ { f , g , h } \theta (x)=C_{\theta} [p_\theta(x)] ,\ \theta \in \{f,g,h\} θ(x)=Cθ​[pθ​(x)], θ∈{f,g,h}

接着将 f(x) 和 g(x) 组合通过 softmax 层生成 α

α = s o f t m a x ( f ( x ) T g ( x ) ) \alpha = softmax(f(x)^Tg(x)) α=softmax(f(x)Tg(x))

特征图 α 结合 h(x) 生成注意力特征图 o

o j = ∑ i = 1 N α ( j , i ) h ( x i ) o_j = \sum_{i=1}^{N}\alpha_{(j,i)}h(x_i) oj​=i=1∑N​α(j,i)​h(xi​)

自注意力特征图 o 结合输入特征 x 生成最终输出特征图为 y

y i = γ o i + x i y_i= \gamma o_i + x_i yi​=γoi​+xi​

Few-Shot Object Detection via Sample Processing

Positive-Sample Augmentation Module

样本数量少极易造成模型的过拟合,同时也没有足够数量的尺度分布,多尺度的检测方法也会失效,提出 PSA 模块来增强数据集,主要包括 background sparsity, multiscale replication 和 random clipping.

Few-Shot Object Detection via Sample Processing
  • background sparsity 类似与mask操作,背景区域像素点置零
  • multiscale replication 图片的缩放
  • random clipping 将裁剪的图片置于不同的位置,提高分散程度

Modified Loss Function

样本数量少的情况下容易受相似样本的影响,因此本文主要针对的是微调阶段的类别损失函数进行改进。不仅需要给 positive samples 更高的分数,而且需要抑制 hard-negative samples,因此模型可以减少这些易混淆样本的影响,损失函数设计为

L f l = − ∑ i = 0 S × S I i j o b j ∑ [ α p i ^ ( c ) l o g ( p i ( c ) ) ] + β ( 1 − p i ( c ) ^ ) ( p i ( c ) ) ε l o g ( 1 − p i ( c ) ) L_{fl} = -\sum_{i=0}^{S\times S}I_{ij}^{obj}\sum[\alpha \hat{p_i}(c)log(p_i(c))]+\beta(1-\hat{p_i(c)})(p_i(c))^{\varepsilon}log(1-p_i(c)) Lfl​=−i=0∑S×S​Iijobj​∑[αpi​^​(c)log(pi​(c))]+β(1−pi​(c)^​)(pi​(c))εlog(1−pi​(c))

输入的图像被分成 S × S S \times S S×S 的网格, I i j o b j I_{ij}^{obj} Iijobj​ 是目标对应位置的锚框, p i ^ ( c ) \hat{p_i}(c) pi​^​(c) 和 p i ( c ) p_i(c) pi​(c) 分别真实值和预测值, ( p i ( c ) ) ε (p_i(c))^{\varepsilon} (pi​(c))ε​​​ 是添加的权重值​,当真实值为0时,预测值越高表示越难排除目标,添加权重来改进难样本。

同时,hard-positive samples 需要获得更多关注,因此损失函数增加 L C S L_{CS} LCS​ 项,利用余弦相似度来减少类间相似度, L C S L_{CS} LCS​ 定义如下:

L C S = − ∑ i = 0 S × S γ I i j o b j ∑ c ∈ c l a s s e s ( ω c T f i j ∣ ∣ ω c T ∣ ∣ ∣ ∣ f i j ∣ ) ε − 1 L_{CS} = -\sum_{i=0}^{S\times S}\gamma I_{ij}^{obj}\sum_{c \in classes}(\frac{\omega ^{T}_{c}f_{ij}}{||\omega_{c}^{T}|| ||f_{ij}|})^{\varepsilon -1} LCS​=−i=0∑S×S​γIijobj​c∈classes∑​(∣∣ωcT​∣∣∣∣fij​∣ωcT​fij​​)ε−1

f i j f_{ij} fij​表示对应位置的锚框的特征, ω c \omega_{c} ωc​ 表示类别 C 对应最后一层的权重向量, ω c T f i j \omega ^{T}_{c}f_{ij} ωcT​fij​ 的值越大对应说明正样本偏离类别比较大, ε \varepsilon ε 用来强调, ε − 1 \varepsilon - 1 ε−1​ 表示 hard-positive samples 应该比 hard-negative samples 对应的权重小,因为后者的值更大。

综上,类别损失函数为 L c l s = L f l + L C S L_{cls} = L_{fl} + L_{CS} Lcls​=Lfl​+LCS​,总体损失函数为 L o s s = L r e g + L c l s + L c o n f Loss = L_{reg} + L_{cls} + L_{conf} Loss=Lreg​+Lcls​+Lconf​

Experiments

VOC

Few-Shot Object Detection via Sample Processing

COCO

Few-Shot Object Detection via Sample Processing

继续阅读