CNN Models: A Summary of SENet

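  • Squeeze-and-Excitation Networks (SENet) is an image recognition architecture published in 2017 by the autonomous driving company Momenta.
  • SE block architecture diagram:
[Figure: SE block structure]
  • Key points about SENet
    • An SE network can be built by stacking SE blocks.
    • SE blocks can also be embedded into almost all existing network architectures.
    • In earlier layers, SE blocks improve the quality of shared low-level representations in a class-agnostic manner.
    • In later layers, SE blocks become increasingly class-specific.
    • The benefits of feature recalibration by SE blocks accumulate across the whole network.
  • Example applications of the SE block:
[Figure: example applications of the SE block]
  • SE block code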
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import mxnet as mx
from mxnet.gluon import nn
__all__ = ['SE']

class SE(mx.gluon.HybridBlock):
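    """
    Squeeze-and-Excitation (SE) block.
    See the paper "Squeeze-and-Excitation Networks".

    units: number of channels of the input feature map
    r: reduction ratio of the bottleneck (default 16)
    """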
    def __init__(self, units, r=16, **kwargs):
        super(SE, self).__init__(**kwargs)
        self.body = nn.HybridSequential()
        self.body.add(nn.Conv2D(channels=units // r, kernel_size=1))
        self.body.add(Activation('relu'))
        self.body.add(nn.Conv2D(channels=units, kernel_size=1))
        self.body.add(Activation('sigmoid'))

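    def hybrid_forward(self, F, x):
        # Squeeze: global average pooling, (N, C, H, W) --> (N, C, 1, 1)
        w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
        # Excitation: per-channel gates in (0, 1), (N, C, 1, 1) --> (N, C, 1, 1)
        w = self.body(w)
        # Scale: broadcast the gates over H and W
        y = F.broadcast_mul(x, w)
        return y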
#---------------------------------------------------------------------------------
class Activation(nn.Activation):
    def __init__(self, activation, **kwargs):
        super(Activation, self).__init__(activation, **kwargs)
        attrs = AttributeScope.current.attrs
        self.mirroring_level = attrs.get('mirroring_level', 0)

    def hybrid_forward(self, F, x):
        _kwargs = dict()
        if F is mx.symbol and self.mirroring_level >= 1:
            _kwargs['force_mirroring'] = 'True'
        return F.Activation(x, act_type=self._act_type, name='fwd', **_kwargs)
#---------------------------------------------------------------------------------
class AttributeScope(object):
    current = None
    def __init__(self, attrs=None):
        self.old_scope = None
        # Avoid a shared mutable default argument
        self.attrs = {} if attrs is None else attrs

    def __enter__(self):
        self.old_scope = AttributeScope.current
        attrs = AttributeScope.current.attrs.copy()
        attrs.update(self.attrs)
        self.attrs = attrs
        AttributeScope.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self.old_scope
        AttributeScope.current = self.old_scope


AttributeScope.current = AttributeScope()
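
To make the three steps of the SE block (squeeze, excitation, scale) easier to follow without MXNet, here is a minimal NumPy sketch of the same forward computation. It is only an illustration under stated assumptions: the function name `se_forward`, the weight names `w1`, `b1`, `w2`, `b2`, and the random initialization are hypothetical and not part of the code above. It uses fully connected layers, as in the paper; on an (N, C, 1, 1) input these compute the same thing as the 1x1 convolutions used in the Gluon block.

```python
import numpy as np


def se_forward(x, w1, b1, w2, b2):
    """Apply a Squeeze-and-Excitation step to x of shape (N, C, H, W).

    w1: (C, C // r), b1: (C // r,)  -- bottleneck layer (hypothetical names)
    w2: (C // r, C), b2: (C,)       -- expansion layer (hypothetical names)
    """
    # Squeeze: global average pooling over H and W -> (N, C)
    z = x.mean(axis=(2, 3))
    # Excitation: C -> C // r with ReLU, then C // r -> C with sigmoid
    h = np.maximum(z @ w1 + b1, 0.0)
    s = 1.0 / (1.0 + np.exp(-(h @ w2 + b2)))
    # Scale: broadcast the per-channel gates over H and W
    return x * s[:, :, None, None]


# Illustrative sizes only; r = 16 matches the default in the block above.
rng = np.random.default_rng(0)
N, C, H, W, r = 2, 32, 8, 8, 16
x = rng.standard_normal((N, C, H, W))
w1 = rng.standard_normal((C, C // r)) * 0.1
b1 = np.zeros(C // r)
w2 = rng.standard_normal((C // r, C)) * 0.1
b2 = np.zeros(C)

y = se_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 32, 8, 8)
```

The output has the same shape as the input, and each channel is multiplied by a gate between 0 and 1, which is why the block can be dropped into an existing architecture without changing the surrounding layer shapes.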