天天看點

Batch Normalization、Instance normalization簡單了解

1. Batch Normalization

首先,簡短介紹一下Batch Normalization,通常Batch Normalization更為大家所知,是以在此簡要介紹BN來引入Instance Normalization。

引入BN層主要是為了解決"Internal Covariate Shift"問題,關于這個問題李宏毅老師有個視訊講解比較形象[4],可以參考。Batch Normalization主要是作用在batch上,對NHW做歸一化,對小batchsize效果不好,添加了BN層能加快模型收斂,一定程度上還有的dropout的作用。

BN的基本思想其實相當直覺:因為深層神經網絡在做非線性變換前的激活輸入值(就是那個x=WU+B,U是輸入)随着網絡深度加深或者在訓練過程中,其分布逐漸發生偏移或者變動,之是以訓練收斂慢,一般是整體分布逐漸往非線性函數的取值區間的上下限兩端靠近(對于Sigmoid函數來說,意味着激活輸入值WU+B是大的負值或正值),是以這導緻反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因,而BN就是通過一定的規範化手段,把每層神經網絡任意神經元這個輸入值的分布強行拉回到均值為0方差為1的标準正态分布,其實就是把越來越偏的分布強制拉回比較标準的分布,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導緻損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題産生,而且梯度變大意味着學習收斂速度快,能大大加快訓練速度。

在BN論文中有下面這樣一幅圖,比較清楚的表示了BN具體是怎麼操作的:

Batch Normalization、Instance normalization簡單了解

前三步就是對一個batch内的資料進行歸一化,使得資料分布一緻:沿着通道計算每個batch的均值,計算每個batch的方差,對 X X X做歸一化。重點在第四步,**加入縮放和平移參數 γ , β \gamma,\beta γ,β **。這兩個參數可以通過學習得到,增加這兩個參數的主要目的是完成歸一化之餘,還要保留原來學習到的特征。

總結如下:

  • 沿着通道計算每個batch的均值 μ \mu μ
  • 沿着通道計算每個batch的方差 σ 2 \sigma ^ 2 σ2
  • 對x做歸一化, x ′ = ( x − μ ) / σ 2 + ϵ x' = (x-\mu) / \sqrt{\sigma ^2 + \epsilon} x′=(x−μ)/σ2+ϵ ​
  • 加入縮放和平移變量 γ \gamma γ和 β \beta β ,歸一化後的值, y = γ x ′ + β y=\gamma x' + \beta y=γx′+β

2. Instance Normalization

IN和BN最大的差別是,IN作用于單張圖檔,BN作用于一個batch。IN多适用于生成模型中,例如風格遷移。像風格遷移這類任務,每個像素點的資訊都非常重要,BN就不适合這類任務。BN歸一化考慮了一個batch中所有圖檔,這樣會令每張圖檔中特有的細節丢失。IN對HW做歸一化,同時保證了每個圖像執行個體之間的獨立。

論文中所給的公式如下:

Batch Normalization、Instance normalization簡單了解

總結如下:

  • 沿着通道計算每張圖的均值 μ \mu μ
  • 沿着通道計算每張圖的方差 σ 2 \sigma ^ 2 σ2
  • 對x做歸一化, x ′ = ( x − μ ) / σ 2 + ϵ x' = (x-\mu) / \sqrt{\sigma ^2 + \epsilon} x′=(x−μ)/σ2+ϵ ​
  • 加入縮放和平移變量 γ \gamma γ和 β \beta β ,歸一化後的值, y = γ x ′ + β y=\gamma x' + \beta y=γx′+β
def Instancenorm(x, gamma, beta):

    # x_shape:[B, C, H, W]
    results = 0.
    eps = 1e-5

    x_mean = np.mean(x, axis=(2, 3), keepdims=True)
    x_var = np.var(x, axis=(2, 3), keepdims=True0)
    x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
    results = gamma * x_normalized + beta
    return results           

複制

pytorch中使用BN和IN:

class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels 

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) # IN,多用于風格遷移
        
    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())  
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 

        return torch.cat((bn_x, in_x), 1)           

複制

下圖來自何凱明大神2018年的論文Group Normalization[3],可以說很直覺了。

Batch Normalization、Instance normalization簡單了解

Reference:

[1] Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[C]//International conference on machine learning. PMLR, 2015: 448-456.

[2] Ulyanov D, Vedaldi A, Lempitsky V. Instance normalization: The missing ingredient for fast stylization[J]. arXiv preprint arXiv:1607.08022, 2016.

[3] Wu Y, He K. Group normalization[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 3-19.

[4] 【深度學習李宏毅 】 Batch Normalization (中文)

[4] *深入了解Batch Normalization批标準化

[5] Batch Normalization原理與實戰

[6] *BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm總結

GroupNorm、SwitchableNorm總結](https://blog.csdn.net/liuxiao214/article/details/81037416)