- Squeeze-and-Excitation Networks（SENet）是自動駕駛公司 Momenta 於 2017 年公布的圖像識別網絡結構。
- SE block 的網絡結構示意圖如下（原文附圖未收錄）：
CNN模型之SENet總結 - SE Net 關鍵知識點
- SE 網絡可以通過堆疊 SE 模組得到
- SE 模組也可以嵌入到現今幾乎所有的網絡結構中
- 較前面層中的 SE block 以類別無關（class agnostic）的方式，增強可共享的低層表示的品質
- 越後面的層中，SE block 越來越類別相關（class specific）
- SE block 重新調整特徵所帶來的益處可以在整個網絡中積累
- SE block 應用示例（MXNet Gluon 實現）如下
CNN模型之SENet總結
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import mxnet as mx
from mxnet.gluon import nn
__all__ = ['SE']


class SE(mx.gluon.HybridBlock):
    """Squeeze-and-Excitation (SE) block.

    Refer to the paper <Squeeze-and-Excitation Networks>.

    The input is globally average-pooled to one value per channel
    ("squeeze"). It then passes through a 1x1-conv bottleneck
    (units -> units // r -> units) with ReLU and sigmoid activations
    ("excitation"). The resulting per-channel weights rescale the input
    with a broadcast multiply.

    Parameters
    ----------
    units : int
        Output channels of the second 1x1 convolution. Because the weights
        are broadcast-multiplied against the input, this is expected to
        equal the input's channel count C (input assumed (N, C, H, W)).
    r : int, default 16
        Reduction ratio of the bottleneck.
    """

    def __init__(self, units, r=16, **kwargs):
        super(SE, self).__init__(**kwargs)
        # units // r is 0 whenever units < r, which would create a conv
        # layer with zero output channels; keep at least one channel.
        hidden = max(1, units // r)
        self.body = nn.HybridSequential()
        self.body.add(nn.Conv2D(channels=hidden, kernel_size=1))
        # Activation is this module's mirroring-aware nn.Activation subclass.
        self.body.add(Activation('relu'))
        self.body.add(nn.Conv2D(channels=units, kernel_size=1))
        self.body.add(Activation('sigmoid'))

    def hybrid_forward(self, F, x):
        # (N, C, H, W) --> (N, C, 1, 1): global average per channel.
        w = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
        # (N, C, 1, 1) --> (N, C, 1, 1): sigmoid gates per channel.
        w = self.body(w)
        # Scale every spatial position of each channel by that channel's gate.
        return F.broadcast_mul(x, w)
#---------------------------------------------------------------------------------
class Activation(nn.Activation):
    """``nn.Activation`` variant that can request forced mirroring.

    On construction it captures ``mirroring_level`` from the currently
    active ``AttributeScope`` (0 when the key is absent). When the block
    runs symbolically (``F is mx.symbol``) with a level of at least 1, the
    activation op receives ``force_mirroring='True'``. This is presumably
    MXNet's memory-mirroring (recompute-in-backward) hint; confirm against
    the MXNet version in use.
    """

    def __init__(self, activation, **kwargs):
        super(Activation, self).__init__(activation, **kwargs)
        # Read once, at construction: later scope changes do not affect it.
        self.mirroring_level = AttributeScope.current.attrs.get('mirroring_level', 0)

    def hybrid_forward(self, F, x):
        if F is mx.symbol and self.mirroring_level >= 1:
            extra = {'force_mirroring': 'True'}
        else:
            # Imperative (ndarray) mode or mirroring disabled: no extras.
            extra = {}
        return F.Activation(x, act_type=self._act_type, name='fwd', **extra)
#---------------------------------------------------------------------------------
class AttributeScope(object):
    """Context manager for nested attribute dictionaries.

    ``AttributeScope.current`` points at the innermost active scope. While
    a scope is active, its ``attrs`` is the merge of the enclosing scope's
    ``attrs`` with this scope's own attributes, and inner values override
    outer ones. Blocks such as ``Activation`` read
    ``AttributeScope.current.attrs`` when they are constructed.

    Parameters
    ----------
    attrs : dict, optional
        Attributes contributed by this scope. The scope never mutates it.
    """

    current = None

    def __init__(self, attrs=None):
        self.old_scope = None
        # This scope's own attributes. They are kept apart from the merged
        # view so that re-entering under a different parent does not carry
        # over attributes inherited from a previous parent.
        self._own_attrs = {} if attrs is None else attrs
        # Merged view read by consumers; rebuilt on every __enter__.
        self.attrs = self._own_attrs

    def __enter__(self):
        self.old_scope = AttributeScope.current
        merged = AttributeScope.current.attrs.copy()
        merged.update(self._own_attrs)
        self.attrs = merged
        AttributeScope.current = self
        return self

    def __exit__(self, ptype, value, trace):
        if self.old_scope is None:
            raise RuntimeError('AttributeScope exited without a matching __enter__')
        AttributeScope.current = self.old_scope
        self.old_scope = None


# Root scope with no attributes, so AttributeScope.current is never None.
AttributeScope.current = AttributeScope()