
CVPR 2021 | Activate or Not: Learning Customized Activation — Reading Notes (Keras Implementation)

Megvii, CVPR 2021

Abstract

The paper proposes ACON, a simple, effective, and general activation function that learns whether or not to activate a neuron. It observes that Swish, the recently popular NAS-searched activation, can be interpreted as a smooth approximation of ReLU. In the same way, the more general Maxout family is smoothed into the general ACON family, which makes Swish a special case of ACON. meta-ACON explicitly learns the parameter that switches between the nonlinear (activated) and linear (inactive) regimes, providing a new design space. Simply changing the activation function proves effective both on small models and on highly optimized large models: top-1 ImageNet accuracy improves by 6.7% on MobileNet-0.25 and by 1.8% on ResNet-152. Moreover, ACON transfers naturally to object detection and semantic segmentation, showing that it is an effective drop-in alternative across a variety of tasks.

Main Idea of the Paper

First, the authors start from the smooth maximum, a general smooth approximation of the max function:

$$
S_{\beta}(x_1,\dots,x_n)=\frac{\sum_{i=1}^{n} x_i\, e^{\beta x_i}}{\sum_{i=1}^{n} e^{\beta x_i}}
$$

This is a useful approximation because, in the two limits,

$\beta \rightarrow \infty:\ S_{\beta} \rightarrow \max$ (nonlinear)

$\beta \rightarrow 0:\ S_{\beta} \rightarrow \text{mean}$ (the arithmetic mean, i.e. linear)
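As a quick sanity check of these two limits, here is a minimal NumPy sketch (added by me, not from the original post):

    import numpy as np

    def smooth_max(x, beta):
        # S_beta(x) = sum_i x_i * exp(beta * x_i) / sum_i exp(beta * x_i)
        w = np.exp(beta * x)
        return np.sum(x * w) / np.sum(w)

    x = np.array([1.0, 2.0, 3.0])
    print(smooth_max(x, beta=50.0))   # ~3.0  -> approaches max(x)
    print(smooth_max(x, beta=0.0))    #  2.0  -> the arithmetic mean of x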

Next, consider the case $n = 2$:

$$
S_{\beta}(\eta_{a}(x),\eta_{b}(x)) = (\eta_{a}(x)-\eta_{b}(x))\cdot \sigma\!\left[\beta\,(\eta_{a}(x)-\eta_{b}(x))\right]+\eta_{b}(x)
$$
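One way to see this identity (a short derivation added here for completeness, with $\sigma$ the sigmoid $\sigma(y)=1/(1+e^{-y})$ and $\eta_a$ shorthand for $\eta_a(x)$):

$$
S_{\beta}(\eta_a,\eta_b)
=\frac{\eta_a e^{\beta \eta_a}+\eta_b e^{\beta \eta_b}}{e^{\beta \eta_a}+e^{\beta \eta_b}}
=\eta_a\,\sigma[\beta(\eta_a-\eta_b)]+\eta_b\bigl(1-\sigma[\beta(\eta_a-\eta_b)]\bigr)
=(\eta_a-\eta_b)\cdot\sigma[\beta(\eta_a-\eta_b)]+\eta_b
$$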

The paper then develops three variants of ACON:


(1) ACON-A offers a new perspective for understanding Swish as a smoothed ReLU;

(2) ACON-B extends this smoothing from ReLU (a special case of Maxout) to the general Maxout family, turning Maxout units into the more general ACON family, just as ReLU smooths into Swish;

(3) ACON-C covers the previous forms and even more complex ones, adapting through two learnable parameters $p_1$ and $p_2$; the resulting formulas are summarized right below.
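For reference, substituting the corresponding $\eta_a(x)$ and $\eta_b(x)$ into the $n=2$ formula above gives the three variants (my restatement of the paper's summary; ACON-C is the form implemented in the Keras code later in this post):

$$
\begin{aligned}
\text{ACON-A (Swish)}:&\quad \eta_a(x)=x,\ \eta_b(x)=0 &&\Rightarrow\quad f(x)=x\cdot\sigma(\beta x)\\
\text{ACON-B}:&\quad \eta_a(x)=x,\ \eta_b(x)=p\,x &&\Rightarrow\quad f(x)=(1-p)\,x\cdot\sigma[\beta(1-p)\,x]+p\,x\\
\text{ACON-C}:&\quad \eta_a(x)=p_1 x,\ \eta_b(x)=p_2 x &&\Rightarrow\quad f(x)=(p_1-p_2)\,x\cdot\sigma[\beta(p_1-p_2)\,x]+p_2\,x
\end{aligned}
$$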

As mentioned above, the ACON family controls whether a neuron is activated through the value of $\beta$ (when $\beta = 0$ the output degenerates to the arithmetic mean, i.e. linear, not activated). ACON therefore needs an adaptive function that computes $\beta$ from the input. The design space of this adaptive function includes layer-wise, channel-wise, and pixel-wise variants, i.e. one $\beta$ per layer, per channel, or per pixel. The paper finds experimentally that the channel-wise version works best; this is meta-ACON, and it is the version implemented in the Keras code below.

Keras Implementation

The following is a Keras version implemented from the paper and the official PyTorch source. The feature tensor must be channels-last. Code link

# Assumed imports for a tf.keras environment (adjust to your Keras version;
# conv_utils is a private tf.keras module, standalone Keras users may need a
# different path).
import tensorflow as tf
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     GlobalAvgPool2D, Layer, Reshape)
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.python.keras.utils import conv_utils


def meta_acon(inputs, r=16):
    # Channel-wise meta-ACON: an SE-style bottleneck computes a per-channel beta,
    # which is then fed to the ACON-C layer together with the original features.
    in_dims = int(inputs.shape[-1])
    temp_dims = max(r, in_dims // r)
    x = GlobalAvgPool2D()(inputs)
    x = Reshape((1, 1, in_dims))(x)
    x = Conv2D(temp_dims, 1)(x)
    x = BatchNormalization()(x)
    x = Conv2D(in_dims, 1)(x)
    x = BatchNormalization()(x)
    x = Activation(activation='sigmoid')(x)
    return ACON_C()([inputs, x])
           
class ACON_C(Layer):
    """ACON-C activation: (p1 - p2) * x * sigmoid(beta * (p1 - p2) * x) + p2 * x.

    Takes a list of two tensors [x, beta].
    data_format: A string; the input feature must be channels-last.
    """
    def __init__(self,
                 data_format=None,
                 p1_initializer='glorot_uniform',
                 p1_regularizer=None,
                 p1_constraint=None,
                 p2_initializer='glorot_uniform',
                 p2_regularizer=None,
                 p2_constraint=None,
                 **kwargs):
        super(ACON_C, self).__init__(**kwargs)
        self.supports_masking = True
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.p1_initializer = initializers.get(p1_initializer)
        self.p1_regularizer = regularizers.get(p1_regularizer)
        self.p1_constraint = constraints.get(p1_constraint)
        self.p2_initializer = initializers.get(p2_initializer)
        self.p2_regularizer = regularizers.get(p2_regularizer)
        self.p2_constraint = constraints.get(p2_constraint)

    def build(self, input_shape):
        # `inputs` is a list [x, beta], so input_shape is a list of two shapes.
        channel_axis = -1
        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[0][channel_axis]
        # Per-channel learnable parameters, broadcastable over (N, H, W, C).
        kernel_shape = (1, 1, 1, input_dim)
        self.p1 = self.add_weight(shape=kernel_shape,
                                  initializer=self.p1_initializer,
                                  name='kernel_p1',
                                  regularizer=self.p1_regularizer,
                                  constraint=self.p1_constraint)

        self.p2 = self.add_weight(shape=kernel_shape,
                                  initializer=self.p2_initializer,
                                  name='kernel_p2',
                                  regularizer=self.p2_regularizer,
                                  constraint=self.p2_constraint)

    def call(self, inputs):
        # ACON-C: (p1 - p2) * x * sigmoid(beta * (p1 - p2) * x) + p2 * x
        x, beta = inputs
        x1 = self.p1 * x
        x2 = self.p2 * x
        temp = x1 - x2
        return temp * tf.sigmoid(beta * temp) + x2

    def get_config(self):
        config = {
            'data_format': self.data_format,
            'p1_initializer': initializers.serialize(self.p1_initializer),
            'p1_regularizer': regularizers.serialize(self.p1_regularizer),
            'p1_constraint': constraints.serialize(self.p1_constraint),
            'p2_initializer': initializers.serialize(self.p2_initializer),
            'p2_regularizer': regularizers.serialize(self.p2_regularizer),
            'p2_constraint': constraints.serialize(self.p2_constraint),
        }
        base_config = super(ACON_C, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
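A minimal usage sketch (the toy model below is illustrative and not part of the original code; it simply drops meta_acon in where a ReLU would normally go):

    from tensorflow.keras import Input, Model
    from tensorflow.keras.layers import Dense

    inputs = Input(shape=(32, 32, 3))             # channels-last, as required
    x = Conv2D(64, 3, padding='same')(inputs)
    x = meta_acon(x)                              # meta-ACON in place of ReLU
    x = GlobalAvgPool2D()(x)
    outputs = Dense(10, activation='softmax')(x)
    model = Model(inputs, outputs)
    model.summary()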


           


References

[1] Paper link

[2] PyTorch code