天天看點

簡單了解反向注意力(Reverse Attention)機制

反向注意力(Reverse Attention)機制由《Reverse Attention for Salient Object Detection》一文提出。其核心思想為,在顯著目标檢測(二分割)網絡中,對象的大緻全局位置資訊在網絡的深層便可以獲得,是以Decoder的淺層部分隻需要關注對象的局部細節即可。具體做法則是,将decoder深層的輸出給取反,那麼網絡關注的位置即為對象以外的邊緣部分,進而使得最終結果局部細節更加出色。

Reverse Attention的結構如下圖所示:

簡單了解反向注意力(Reverse Attention)機制

代碼(取自原文github倉庫)如下:

class RA(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(RA, self).__init__()
        self.convert = nn.Conv2d(in_channel, out_channel, 1)
        self.convs = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(out_channel, 1, 3, padding=1),
        )
        self.channel = out_channel
	
	# x:待被施加空間注意力的淺層特征
	# y:用于計算reverse attention map的深層特征
    def forward(self, x, y):
        a = torch.sigmoid(-y)	# reverse并壓縮至0~1區間内以用作空間注意力map
        x = self.convert(x)		# 統一x, y通道數
        x = a.expand(-1, self.channel, -1, -1).mul(x)	# x, y相乘,完成空間注意力
        y = y + self.convs(x)	# 殘差連接配接(圖中未畫出)
        return y
           

繼續閱讀