天天看點

CCNet: Criss-Cross Attention for Semantic Segmentation

論文名稱:CCNet: Criss-Cross Attention for Semantic Segmentation

作者:Zilong Huang,Xinggang Wang Yun,chao Wei,Lichao Huang,Wenyu Liu,Thomas S. Huang

Code:https://github.com/speedinghzl/CCNet

摘要

上下文資訊在視覺了解問題中至關重要,譬如語義分割和目标檢測;

本文提出了一種十字交叉的網絡(Criss-Cross Net)以非常高效的方式擷取完整的圖像上下文資訊:

  1. 對每個像素使用一個十字注意力子產品聚集其路徑上所有像素的上下文資訊;
  2. 通過循環操作,每個像素最終都可以捕獲完整的圖像相關性;
  3. 提出了一種類别一緻性損失來增強子產品的表現。

CCNet具有一下優勢:

  1. 顯存友好:相較于Non-Local減少顯存占用11倍
  2. 計算高效:循環十字注意力減少Non-Local約85%的計算量
  3. SOTA
  4. Achieve the mIoU scores of 81.9%, 45.76% and 55.47% on the Cityscapes test set, the ADE20K validation set and the LIP validation set respectively

介紹

  • 目前FCN在語義分割任務取得了顯著進展,但是由于固定的幾何結構,分割精度局限于FCN局部感受野所能提供的短程感受野,目前已有相當多的工作緻力于彌補FCN的不足,相關工作看論文。
  • 密集預測任務實際上需要高分辨率的特征映射,是以Non-Local的方法往往計算複雜度高,并且占用大量顯存,是以設想使用幾個連續的稀疏連通圖(sparsely-connected graphs)來替換常見的單個密集連通圖( densely-connected graph),提出CCNet使用稀疏連接配接來代替Non-Local的密集連接配接。
  • 為了推動循環十字注意力學習更多的特征,引入了類别一緻損失(category consistent loss)來增強CCNet,其強制網絡将每個像素映射到特征空間的n維向量,使屬于同一類别的像素的特征向量靠得更近。

方法

CCNet可能是受到之前将卷積運算分解為水準和垂直的GCN以及模組化全局依賴性的Non-Local,CCNet使用的十字注意力相較于分解更具優勢,擁有比Non-Local小的多得計算量。

文中認為CCNet是一種圖神經網絡,特征圖中的每個像素都可以被視作一個節點,利用節點間的關系(上下文資訊)來生成更好的節點特征。

最後,提出了同時利用時間和空間上下文資訊的3D十字注意子產品。

網絡結構

CCNet: Criss-Cross Attention for Semantic Segmentation

整體流程如下:

  1. 對于給定的 X X X,使用卷積層獲得降維的特征映射 H H H;
  2. H H H會輸入十字注意力子產品以生成新的特征映射 H ′ H' H′​,其中每個像素都聚集了垂直和水準方向的資訊;
  3. 進行一次循環,将 H ′ H' H′輸入十字注意力,得到 H ′ ′ H'' H′′,其中每個像素實際上都聚集了所有像素的資訊;
  4. 将 H ′ ′ H'' H′′與局部特征表示 X X X進行 C o n c a t e n a t i o n Concatenation Concatenation​​;
  5. 由後續的網絡進行分割。

Criss-Cross Attention

CCNet: Criss-Cross Attention for Semantic Segmentation

主要流程如下:

  1. 使用 1 × 1 1\times 1 1×1​卷積進行降維得到 Q , K ∈ R C ′ × W × H Q,K \in \mathbb{R}^{C' \times W\times H} Q,K∈RC′×W×H​;
  2. 通過Affinity操作生成注意力圖 A ∈ R ( H + W − 1 ) × ( H × W ) A\in\mathbb{R}^{(H+W-1)\times (H\times W)} A∈R(H+W−1)×(H×W)​,其中:
    1. 對于 Q Q Q​空間次元上的的每一個位置 u u u​,我們可以得到一個向量 Q u ∈ R C ′ Q_u\in\mathbb{R}^{C'} Qu​∈RC′​;​​​
    2. 同時,我們在 K K K​上得到一個集合 Ω u ∈ R ( H + W − 1 ) × C ′ \Omega_u \in \mathbb{R}^{(H+W-1) \times C'} Ωu​∈R(H+W−1)×C′​​,其代表着位置 u u u​​​的同一行或同一列;
    3. 令 Ω i , u \Omega_{i,u} Ωi,u​​表示 Ω u \Omega_{u} Ωu​​的第 i i i個元素,Affinity操作可以表示為:

      d i , u = Q u Ω i , u T i ∈ [ 0 , 1 , ⋯   , H + W − 1 ] , u ∈ [ 0 , 1 , ⋯   , H × W ] d_{i,u}=Q_u\Omega_{i,u}^T\qquad i\in [0,1,\cdots,H+W-1],u\in[0,1,\cdots,H\times W] di,u​=Qu​Ωi,uT​i∈[0,1,⋯,H+W−1],u∈[0,1,⋯,H×W]

      其用來表示兩者之間的相關性,最終我們可以得到 D ∈ R ( H + W − 1 ) × ( H × W ) D\in\mathbb{R}^{(H+W-1)\times (H\times W)} D∈R(H+W−1)×(H×W)​​

    4. 最終在通道次元上對 D D D​使用 S o f t m a x Softmax Softmax​,即可得到注意力圖 A A A​,需要注意的是,這裡的通道次元代表的是 H + W − 1 H+W-1 H+W−1​​​​這個次元,其表示某個位置像素與其垂直水準方向上像素的相關性。
  3. 另一方面,依舊使用 1 × 1 1\times 1 1×1卷積生成 V ∈ R C × W × H V \in \mathbb{R}^{C \times W \times H} V∈RC×W×H,我們可以獲得一個向量 V u ∈ R C V_u\in \mathbb{R}^C Vu​∈RC和一個集合 Φ u ∈ R ( H + W − 1 ) × C \Phi_u\in \mathbb{R}^{(H+W-1)\times C} Φu​∈R(H+W−1)×C​
  4. 最後使用Aggregation操作得到最終的特征圖,其定義為:

    H u ′ = ∑ i = 0 H + W − 1 A i , u Φ i , u + H u H'_u=\sum_{i=0}^{H+W-1}A_{i,u}\Phi_{i,u}+H_u Hu′​=i=0∑H+W−1​Ai,u​Φi,u​+Hu​

    其中 H u ′ ∈ R C H'_u\in\mathbb{R}^{C} Hu′​∈RC​​​是某個位置的特征向量。

至此,我們已經能夠捕獲某個位置像素水準和垂直方向上的文本資訊,然而,該像素與周圍的其他像素仍然不存在關系,為了解決這個問題,提出了循環機制。

Recurrent Criss-Cross Attention (RCCA)

通過多次使用CCA來達到對上下文進行模組化,當循環次數R=2時,特征圖中任意兩個空間位置的關系可以定義為:

∃ i ∈ R H + W + 1 , s . t . A i , u = f ( A , u x C C , u y C C , u x , u y ) \exist i\in\mathbb{R}^{H+W+1},s.t.A_{i,u}=f(A,u_{x}^{CC},u^{CC}_y,u_x,u_y) ∃i∈RH+W+1,s.t.Ai,u​=f(A,uxCC​,uyCC​,ux​,uy​)

友善起見,對于特征圖上的兩個位置 ( u x , u y ) (u_x,u_y) (ux​,uy​)和 ( θ x , θ y ) (\theta_x,\theta_y) (θx​,θy​),其資訊傳遞示意圖如下:

CCNet: Criss-Cross Attention for Semantic Segmentation

可以看到,經過兩次循環,原本不相關的位置也能夠建立聯系了。

Learning Category Consistent Features

對于語義分割任務,屬于同一類别的像素應該具有相似的特征,而來自不同類别的像素應該具有相距很遠的特征。

然而,聚集的特征可能存在過度平滑的問題,這是圖神經網絡中的一個常見問題,為此,提出了類别一緻損失。

l v a r = 1 ∣ C ∣ ∑ c ∈ C 1 N c ∑ i = 1 N c φ v a r ( h i , μ i ) l_{var}=\frac{1}{|C|}\sum_{c\in C}\frac{1}{N_c}\sum_{i=1}^{N_c}\varphi_{var}(h_i,\mu_i) lvar​=∣C∣1​c∈C∑​Nc​1​i=1∑Nc​​φvar​(hi​,μi​)

l d i s = 1 ∣ C ∣ ( ∣ C ∣ − 1 ) ∑ c a ∈ C ∑ c b ∈ C φ d i s ( μ c a , μ c b ) l_{dis}=\frac{1}{|C|(|C|-1)}\sum_{c_a\in C}\sum_{c_b\in C}\varphi_{dis}(\mu_{c_a},\mu_{c_b}) ldis​=∣C∣(∣C∣−1)1​ca​∈C∑​cb​∈C∑​φdis​(μca​​,μcb​​)

l r e g = 1 ∣ C ∣ ∑ c ∈ C ∣ ∣ μ c ∣ ∣ l_{reg}=\frac{1}{|C|}\sum_{c\in C}||\mu_c|| lreg​=∣C∣1​c∈C∑​∣∣μc​∣∣

其中的距離函數 φ \varphi φ設計為分段形式,公式如下:

φ v a r = { ∣ ∣ μ c − h i ∣ ∣ − δ d + ( δ d − δ v ) 2 , ∣ ∣ μ c − h i ∣ ∣ > δ d ( ∣ ∣ μ c − h i ∣ ∣ − δ v ) 2 , δ d > ∣ ∣ μ c − h i ∣ ∣ ⩾ δ v 0 ∣ ∣ μ c − h i ∣ ∣ ⩽ δ d \varphi_{var}=\left\{ \begin{array}{l} ||\mu_c-h_i||-\delta{_d}+(\delta{_d}-\delta{_v})^2,&||\mu_c-h_i||>\delta{_d}\\ (||\mu_c-h_i||-\delta{_v})^2,&\delta{_d}>||\mu_c-h_i||\geqslant\delta{_v}\\ 0 &||\mu_c-h_i||\leqslant\delta{_d} \end{array}\right. φvar​=⎩⎨⎧​∣∣μc​−hi​∣∣−δd​+(δd​−δv​)2,(∣∣μc​−hi​∣∣−δv​)2,0​∣∣μc​−hi​∣∣>δd​δd​>∣∣μc​−hi​∣∣⩾δv​∣∣μc​−hi​∣∣⩽δd​​

φ d i s = { ( 2 δ d − ∣ ∣ μ c a − μ c b ∣ ∣ ) 2 , ∣ ∣ μ c a − μ c b ∣ ∣ ⩽ 2 δ d 0 , ∣ ∣ μ c a − μ c b ∣ ∣ > 2 δ d \varphi_{dis}=\left\{\begin{array} {l} (2\delta{_d}-||\mu_{c_a}-\mu_{c_b}||)^2,&||\mu_{c_a}-\mu_{c_b}||\leqslant2\delta{_d}\\ 0,&||\mu_{c_a}-\mu_{c_b}||>2\delta{_d} \end{array}\right. φdis​={(2δd​−∣∣μca​​−μcb​​∣∣)2,0,​∣∣μca​​−μcb​​∣∣⩽2δd​∣∣μca​​−μcb​​∣∣>2δd​​

本文中,距離門檻值的設定為 δ v = 0.5 , δ d = 1.5 \delta{_v}=0.5,\delta{_d}=1.5 δv​=0.5,δd​=1.5

為了加速計算,對RCCA的輸入進行降維,其比率設定為16

總的損失函數定義如下:

l = l s e g + α l v a r + β l d i s + γ l r e g l=l_{seg}+\alpha l_{var}+\beta l_{dis}+\gamma l_{reg} l=lseg​+αlvar​+βldis​+γlreg​

本文中, α , β , γ \alpha,\beta,\gamma α,β,γ​​的值分别為1,1,0.001,

3D Criss-Cross Attention

在2D注意力的基礎上進行推廣,提出3DCCA,其可以在時間次元上收集額外的上下文資訊

CCNet: Criss-Cross Attention for Semantic Segmentation

其流程與2DCCA大緻相同,具體細節差異看論文。

代碼複現

Criss-Cross Attention

def INF(B,H,W):
    # tensor -> torch.size([H]) -> 對角矩陣[H,H] -> [B*W,H,H] 
    # 消除重複計算自身的影響
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
           
class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_ch,ratio=8):
        super(CrissCrossAttention,self).__init__()
        self.q = nn.Conv2d(in_ch, in_ch//ratio, 1)
        self.k = nn.Conv2d(in_ch, in_ch//ratio, 1)
        self.v = nn.Conv2d(in_ch, in_ch, 1)
        self.softmax = nn.Softmax(ch=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1)) # 初始化為0


    def forward(self, x):
        bs, _, h, w = x.size()
        # Q
        x_q = self.q(x)
        # b,c',h,w -> b,w,c',h -> b*w,c',h -> b*w,h,c'
        # 後兩維相當于論文中的Q_u,在此分解為了
        x_q_H = x_q.permute(0,3,1,2).contiguous().view(bs*w,-1,h).permute(0, 2, 1)
        # b,c',h,w -> b,h,c',w -> b*h,c',w -> b*h,w,c'
        x_q_W = x_q.permute(0,2,1,3).contiguous().view(bs*h,-1,w).permute(0, 2, 1)
        # K
        x_k = self.k(x) # b,c',h,w
        # b,c',h,w -> b,w,c',h -> b*w,c',h
        x_k_H = x_k.permute(0,3,1,2).contiguous().view(bs*w,-1,h)
        # b,c',h,w -> b,h,c',w -> b*h,c',w
        x_k_W = x_k.permute(0,2,1,3).contiguous().view(bs*h,-1,w)
        # V
        x_v = self.v(x)
        # b,c,h,w -> b,w,c,h -> b*w,c,h
        x_v_H = x_v.permute(0,3,1,2).contiguous().view(bs*w,-1,h) 
        # b,c,h,w -> b,h,c,w -> b*h,c,w
        x_v_W = x_v.permute(0,2,1,3).contiguous().view(bs*h,-1,w)
        # torch.bmm計算三維的矩陣乘法,如[bs,a,b][bs,b,c]
        # 先計算所有Q_u和K上與位置u同一列的
        energy_H = (torch.bmm(x_q_H, x_k_H)+self.INF(bs, h, w)).view(bs,w,h,h).permute(0,2,1,3) # b,h,w,h
        # 再計算行
        energy_W = torch.bmm(x_q_W, x_k_W).view(bs,h,w,w)
        # 得到注意力圖
        concate = self.softmax(torch.cat([energy_H, energy_W], 3)) # b,h,w,h+w

        # 後面開始合成一張圖
        att_H = concate[:,:,:,0:h].permute(0,2,1,3).contiguous().view(bs*w,h,h)
        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,h:h+w].contiguous().view(bs*h,w,w)
        # 同樣的計算方法
        out_H = torch.bmm(x_v_H, att_H.permute(0, 2, 1)).view(bs,w,-1,h).permute(0,2,3,1) # b,c,h,w
        out_W = torch.bmm(x_v_W, att_W.permute(0, 2, 1)).view(bs,h,-1,w).permute(0,2,1,3) # b,c,h,w
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x # 乘積使得整體可訓練
           

Category Consistent Loss

未找到代碼

實驗

在Cityscapes、ADE20K、COCO、LIP和CamVid資料集上進行了實驗,在一些資料集上實作了SOTA,并且在Cityscapes資料集上進行了消融實驗。

實驗結果

在Cityscapes上的結果:

CCNet: Criss-Cross Attention for Semantic Segmentation

消融實驗

RCCA子產品

通過改變循環次數進行了如下實驗:

CCNet: Criss-Cross Attention for Semantic Segmentation

可以看到,RCCA子產品可以有效的聚集全局上下文資訊,同時保持較低的計算量。

為了進一步驗證CCA的有效性,進行了定性比較:

CCNet: Criss-Cross Attention for Semantic Segmentation

随着循環次數的增加,這些白色圈圈區域的預測逐漸得到糾正,這證明了密集上下文資訊在語義分割中的有效性。

類别一緻損失

CCNet: Criss-Cross Attention for Semantic Segmentation

上圖中的CCL即表示使用了類别一緻損失

CCNet: Criss-Cross Attention for Semantic Segmentation

上述結果表明了分段距離和類别一緻損失的有效性。

對比其他聚集上下文資訊方法

CCNet: Criss-Cross Attention for Semantic Segmentation

同時,對Non Local使用了循環操作,可以看到,循環操作帶來了超過一點的增益,然而其巨量的計算量和顯存需求限制性能

CCNet: Criss-Cross Attention for Semantic Segmentation

可視化注意力圖

CCNet: Criss-Cross Attention for Semantic Segmentation

上圖中可以看到循環操作的有效性。

更多實驗

在ADE20K上的實驗驗證了類别一緻損失(CCL)的有效性:

CCNet: Criss-Cross Attention for Semantic Segmentation

在LIP資料集的實驗結果:

CCNet: Criss-Cross Attention for Semantic Segmentation

在COCO資料集的實驗結果:

CCNet: Criss-Cross Attention for Semantic Segmentation

在CamVid資料上的實驗結果:

CCNet: Criss-Cross Attention for Semantic Segmentation
cv dl

繼續閱讀