天天看點

注意力機制-CA注意力-Coordinate attention

注意力機制學習--CA(Coordinate attention)

    • 簡介
        • CA注意力機制的優勢:
        • 提出不足
    • 算法流程圖
    • 代碼
    • 最後

簡介

CA(Coordinate attention for efficient mobile network design)發表在CVPR2021,幫助輕量級網絡漲點、即插即用。

CA注意力機制的優勢:

1、不僅考慮了通道資訊,還考慮了方向相關的位置資訊。

2、足夠的靈活和輕量,能夠簡單的插入到輕量級網絡的核心子產品中。

提出不足

1、SE注意力中隻關注建構通道之間的互相依賴關系,忽略了空間特征。

2、CBAM中引入了大尺度的卷積核提取空間特征,但忽略了長程依賴問題。

算法流程圖

注意力機制-CA注意力-Coordinate attention

step1: 為了避免空間資訊全部壓縮到通道中,這裡沒有使用全局平均池化。為了能夠捕獲具有精準位置資訊的遠端空間互動,對全局平均池化進行的分解,具體如下:

注意力機制-CA注意力-Coordinate attention
注意力機制-CA注意力-Coordinate attention

對尺寸為 C ∗ H ∗ W C*H*W C∗H∗W輸入特征圖 I n p u t Input Input分别按照 X X X方向和 Y Y Y方向進行池化,分别生成尺寸為 C ∗ H ∗ 1 C*H*1 C∗H∗1和 C ∗ 1 ∗ W C*1*W C∗1∗W的特征圖。如下圖所示(圖檔粘貼自B站大佬渣渣的熊貓潘)。

注意力機制-CA注意力-Coordinate attention

step2:将生成的 C ∗ 1 ∗ W C*1*W C∗1∗W的特征圖進行變換,然後進行concat操作。公式如下:

注意力機制-CA注意力-Coordinate attention

将 z h z^h zh和 z w z^w zw進行concat後生成如下圖所示的特征圖,然後進行F1操作(利用1*1卷積核進行降維,如SE注意力中操作)和激活操作,生成特征圖 f ∈ R C / r × ( H + W ) × 1 f \in \mathbb{R}^{C/r\times(H+W)\times1} f∈RC/r×(H+W)×1。

注意力機制-CA注意力-Coordinate attention

step3:沿着空間次元,再将 f f f進行split操作,分成 f h ∈ R C / r × H × 1 f^h\in \mathbb{R}^{C/r\times H \times1} fh∈RC/r×H×1和 f w ∈ R C / r × 1 × W f^w\in \mathbb{R}^{C/r\times1\times W} fw∈RC/r×1×W,然後分别利用 1 × 1 1 \times 1 1×1卷積進行升次元操作,再結合sigmoid激活函數得到最後的注意力向量 g h ∈ R C × H × 1 g^h \in \mathbb{R}^{C \times H \times 1 } gh∈RC×H×1和 g w ∈ R C × 1 × W g^w\in \mathbb{R}^{C \times1\times W} gw∈RC×1×W。

注意力機制-CA注意力-Coordinate attention

最後:Coordinate Attention 的輸出公式可以寫成:

注意力機制-CA注意力-Coordinate attention

代碼

代碼粘貼自github。CoordAttention

位址:https://github.com/houqb/CoordAttention/blob/main/mbv2_ca.py

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.relu = h_swish()

    def forward(self, x):
        identity = x
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.relu(y) 
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        x_h = self.conv2(x_h).sigmoid()
        x_w = self.conv3(x_w).sigmoid()
        x_h = x_h.expand(-1, -1, h, w)
        x_w = x_w.expand(-1, -1, h, w)

        y = identity * x_w * x_h

        return y
           

最後

CA不僅考慮到空間和通道之間的關系,還考慮到長程依賴問題。通過實驗發現,CA不僅可以實作精度提升,且參數量、計算量較少。

簡單進行記錄,如有問題請大家指正。

繼續閱讀