
Deep Residual Shrinkage Networks: A New Deep Attention Algorithm (with Code)

This article introduces a new deep attention algorithm: the Deep Residual Shrinkage Network (DRSN). Functionally, the deep residual shrinkage network is a feature-learning method designed for data with heavy noise or high redundancy. The article first reviews the relevant background, then presents the motivation behind the deep residual shrinkage network and its concrete implementation, hopefully serving as a useful reference.

1. Background

The deep residual shrinkage network is built mainly on three foundations: deep residual networks, the soft thresholding function, and the attention mechanism.

1.1 Deep Residual Networks

Deep residual networks are without doubt among the most successful deep learning algorithms of recent years, having passed forty thousand citations on Google Scholar. Compared with ordinary convolutional neural networks, a deep residual network uses cross-layer identity shortcuts, which eases the training of deep networks. The backbone of a deep residual network is a stack of residual blocks; a common form of residual block is shown in the figure below.

[Figure: a common residual building block]
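
To make the structure concrete, here is a minimal Keras sketch of such a pre-activation residual block, consistent in style with the full program in Section 4 (the layer settings are illustrative assumptions):

from keras.layers import BatchNormalization, Activation, Conv2D, add

def residual_block(x, channels):
    # Pre-activation residual block: (BN -> ReLU -> Conv) twice,
    # then add the identity shortcut across the two convolutions.
    # Note: x must already have `channels` feature channels for the add.
    y = BatchNormalization()(x)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(channels, 3, padding='same')(y)
    return add([x, y])  # the identity shortcut eases training of deep networks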

1.2 Soft Thresholding

Soft thresholding is the core step of most signal denoising methods. First, a positive threshold is set. The threshold must not be too large; specifically, it must not exceed the maximum absolute value of the input data, or the output would be all zeros. Soft thresholding then sets inputs whose absolute values fall below the threshold to zero, and shrinks the inputs whose absolute values exceed the threshold toward zero, i.e., y = sign(x) · max(|x| − τ, 0). The input-output relationship is shown in figure (a) below.

[Figure: (a) the soft thresholding function; (b) its derivative]

The derivative of the output y with respect to the input x is shown in figure (b) above. Observe that the derivative is either 0 or 1. In this respect, soft thresholding resembles the ReLU activation function and is likewise friendly to gradient backpropagation during training. Note, however, that the choice of threshold directly determines the result of soft thresholding, and selecting a good threshold remains a difficult problem.
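
As a quick numerical illustration, here is a minimal NumPy sketch of soft thresholding (the example values are arbitrary):

import numpy as np

def soft_threshold(x, tau):
    # y = sign(x) * max(|x| - tau, 0): inputs with |x| <= tau become zero,
    # and the remaining inputs are shrunk toward zero by tau.
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

x = np.array([-2.0, -0.5, 0.0, 0.3, 1.5])
print(soft_threshold(x, tau=1.0))  # approximately [-1.  0.  0.  0.  0.5]

The piecewise derivative of this mapping is exactly the 0-or-1 pattern described above.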

1.3 Attention Mechanism

Attention mechanisms have been an extremely active research topic in deep learning in recent years, and the Squeeze-and-Excitation Network (SENet) is one of the classic attention algorithms. As shown in the figure below, SENet learns a set of weighting coefficients through a small embedded network and uses them to weight the individual feature channels. This is an attention mechanism in action: first assess the importance of each feature channel, then assign each channel a weight according to its importance.

[Figure: the Squeeze-and-Excitation (SE) module]

As shown in the figure below, the SE module can be integrated into a residual block. In this configuration, the cross-layer identity shortcut makes the network easier to train. It is also worth pointing out that the weighting coefficients are computed from each sample itself; in other words, every sample can have its own unique set of weights.

[Figure: an SE module integrated into a residual block]
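
As a concrete sketch, the squeeze-and-excitation step can be written in a few lines of Keras (channels-last data format assumed; the reduction ratio of 4 is an illustrative assumption, not a value from this article):

from keras.layers import Dense, GlobalAveragePooling2D, Reshape, multiply

def se_block(x, channels, reduction=4):
    # Squeeze: global average pooling yields one descriptor per channel
    w = GlobalAveragePooling2D()(x)
    # Excitation: a small two-layer network outputs per-channel weights in (0, 1)
    w = Dense(channels // reduction, activation='relu')(w)
    w = Dense(channels, activation='sigmoid')(w)
    w = Reshape((1, 1, channels))(w)
    # Reweight: scale each feature channel by its learned coefficient
    return multiply([x, w])

Because the weights are computed from x itself, every sample in a batch receives its own set of channel weights.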

2. Deep Residual Shrinkage Networks

This section presents, in turn, the motivation behind the deep residual shrinkage network, its implementation, and its advantages.

2.1 Motivation

First, most real-world data, be it images, speech, or vibration signals, contain some amount of noise or redundant information. Broadly speaking, any information in a sample that is irrelevant to the pattern recognition task at hand can be regarded as noise or redundancy, and such information is likely to hurt performance on that task.

Second, the amount of noise or redundancy usually differs from sample to sample: some samples contain more, others less. This calls for algorithms that can set the relevant parameters individually, according to the characteristics of each sample.

Driven by these two observations, a natural question arises: can soft thresholding, the core of classical signal denoising, be brought into deep residual networks, and how should its thresholds be chosen? The deep residual shrinkage network offers one answer.

2.2 Implementation

The deep residual shrinkage network fuses the deep residual network, SENet, and soft thresholding. As shown in the figure below, it simply replaces the "reweighting" in the residual-style SE module with "soft thresholding". In SENet, the small embedded network produces a set of weighting coefficients; in the deep residual shrinkage network, the same kind of small network produces a set of thresholds instead.

[Figure: the basic building block of the deep residual shrinkage network]

為了獲得合适的門檻值,相較于原始的SENet,深度殘差收縮網絡裡面的小型網絡的結構也進行了調整。具體而言,該小型網絡所輸出的門檻值,是(各個特征通道的絕對值的平均值)×(一組0和1之間的系數)。通過這種方式,深度殘差收縮網絡不僅確定了所有門檻值都為正數,而且門檻值不會太大(不會使所有輸出都為0)。

As shown in the figure below, the overall architecture of the deep residual shrinkage network is identical to that of an ordinary deep residual network: an input layer, an initial convolutional layer, a stack of basic blocks, and finally global average pooling followed by a fully connected output layer.

[Figure: the overall architecture of the deep residual shrinkage network]

2.3 Advantages

First, the thresholds required by soft thresholding are set automatically by a small embedded network, removing the need for the expert knowledge that manual threshold selection requires.

Second, the network guarantees that the thresholds are positive and fall within a suitable range, so the output never collapses to all zeros.

Third, every sample gets its own set of thresholds, which makes the deep residual shrinkage network well suited to situations where the noise level varies from sample to sample.

3. Conclusion

Since noise and redundant information are everywhere, the deep residual shrinkage network, or more generally this combination of an attention mechanism with soft thresholding, may have broad room for extension and application.

4. Keras Code

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""

from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Add uniform random noise (shape taken from the data itself, so this also works with channels_first)
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random(x_train.shape)
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random(x_test.shape)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    # Zero-pad the channel dimension so the identity branch matches out_channels
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs,-1)
    inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual shrinkage block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 
                          padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding: sign(residual) * max(|residual| - thres, 0)
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])   # all-zeros tensor of the same shape
        n_sub = keras.layers.maximum([sub, zeros])  # max(|residual| - thres, 0)
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling (note the pool size of (1, 1): it only strides, without averaging over a window)
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
            
        # Zero-padding to match channels (zero padding is used here rather than a 1x1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
        
        residual = keras.layers.add([residual, identity])
    
    return residual


# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])           


5. TFLearn Code

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
print('Train accuracy:', training_acc)
print('Test accuracy:', validation_acc)

複制

Reference

M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, "Deep Residual Shrinkage Networks for Fault Diagnosis," IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898.