天天看點

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

作者:極客資料俠

關鍵字:batch normalization,tensorflow,批量歸一化

bn簡介

batch normalization批量歸一化,目的是對神經網絡的中間層的輸出進行一次額外的處理,經過處理之後期望每一層的輸出盡量都呈現出均值為0标準差是1的相同的分布上,進而保證每一層的輸出穩定不會劇烈波動,進而有效降低模型的訓練難度快速收斂,同時對大學習率的容忍度增強,避免了大學習率的梯度爆炸問題,是以配合大學習率能加快收斂,跳出不好的局部極值。

原理概括

bn的實作方法是:針對一個批次的資料,對網絡的隐藏層(中間層)的輸出做批量歸因化操作,該操作包括兩個部分:

  • 1.标準化:對一批次資料在中間層的`每個神經元`的輸出進行标準化,一個資料一個神經元隻有一個輸出,一組資料一個神經元就是一個一維向量,對該向量每個值減去均值除标準差,這塊的操作完全根據輸入資料決定或者說由上層操作決定,該層或者說神經網絡不設定參數學習
  • 2.分布還原:在完成标準化之後,再學習一個線性映射(wx+b)中的w和b,再把标準化的輸出進行映射,具體怎麼映射由神經網絡自己學習得到。

操作如下圖所示

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

bn操作兩步走示意圖

标準化很好了解,為啥要在标準化之後再把分布拉回來,那不是做了無用功嗎,不是的。标準化破壞了資料的原始分布,可能導緻輸入給下遊非線性函數比如激活函數的時候産生負面效果,是以加入還原線性變換進行适度還原,所謂适度還原就是不用擔心資料的原始分布被破壞導緻影響網絡訓練的問題,因為還有個還原層,它的上限就是把标準化徹底還原成原始分布(w是标準差,b是均值),下限就是保留标準化,中間水準就是把标準化的結果稍微拉動一下,具體還原函數還原到什麼程度完全由網絡自行學習決定,相當于人為讓分布标準化統一,也給模型留了個口子如果這個人為動作不合理就打退回去,等同于有自動審查機制的每層分布統一,在不影響模型學習的情況下盡量讓分布統一。

兩個階段的各自目的說明白之後還有幾個重點問題沒有解決:

  • 在哪個中間層進行bn:通常在全連接配接層或者卷積層後面,在激活函數之前,也有說在激活函數之後的,反正在上一層的輸出到下一層的輸入之間進行
  • bn要學習的參數:一個bn是一層網絡,接受比如一層全連接配接的輸出,其中全連接配接的每個神經元的輸出的均值和标準化是由輸入決定的不需要學習,而bn要學習是第二階段還原函數的w和b,w向量和b向量是在一層bn中共享的
  • 訓練和測試bn怎麼保持一緻:訓練過程以每個批次的實際均值和标準差進行标準化,而預測/測試過程不論是單樣本預測還是批量預測,都是以整個訓練過程的均值和标準差進行标準化,bn采用了`滑動平均`的方式随着訓練批次的進去,慢慢逼近訓練集全集在每一個bn層上的均值和标準差,最終在每一層bn上都有一組均值向量和标準差向量,相當于每一層每一個神經元通過滑動平均的方法都記錄下了最終的各自的均值和标準差。

圖示計算過程

先看下公式

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

bn計算公式

公式和之前講的多一個小e,通常是一個極小的數比如1e-3,目的防止分母為0。

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

深入了解bn圖示

公式太抽象,畫一下bn計算和應用的圖示,如上圖是一個(None, 3)的輸入到一層全連接配接(3, 4)之後加入bn再到激活函數的資料流轉情況

滑動平均估算整體均值标準差

模型訓練好之後,需要計算訓練集全集在各個bn層的均值标準差權重向量,實作方法是在訓練過程中記錄中間狀态不斷修正調整逼近結果,這樣等模型訓練完,這個結果也記錄在網絡變量中,在預測的時候直接調用即可。具體公式如下

mean_value = mean_value * decay + batch_mean * (1 - decay)
var_value = var_value * decay + batch_var * (1 - decay)           

其中mean_value吸收了訓練資料之後每一輪都會更新,初始值是0,同理var_value,他的初始值是1,decay推薦是0.9或者0.99,舉例

a = [1, 2, 1.5, 2, 3, 3, 3.2, 1.2, 0.5, 0.8, 0.3, 2, 2.1, 2.2, 1.6] * 100
sum(a) / len(a)  # 1.759999999999992
mean_value = 0
for i in a:
    mean_value = mean_value * 0.99 + i * 0.01  # 1.755411190749514           

理論上batch越多結果越接近真實,另外decay越大越穩定,decay越小新加入的batch mean占比重大波動越大,推薦0.9以上是求穩定,是以需要更多的batch,這樣才能避免還沒有畢竟真實就停止計算了,導緻測試集的參考均值和方差不準。

tensorflow實作方法

推薦使用from tensorflow.contrib.layers.python.layers import batch_norm,傳入需要bn的tensor,将是否是訓練還是測試/預測也作為一個tensor傳入進去,通過tf.cond+布爾标量實作邏輯判斷,其中訓練中batch_norm的參數is_training=True,預測is_training=False,另外測試時reuse=True,表示訓練和預測網絡中共享這個BN層,否則會出現兩個BN層,預測時拿得是初始化的w=0和b=1,導緻預測集效果出問題。

def batch_norm_layer(value, is_training, scope):
    def batch_statistics():
        return batch_norm(value, decay=0.9, updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=True, scope=scope)

    def population_statistics():
        return batch_norm(value, decay=0.9, updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=False, reuse=True, scope=scope)

    return tf.cond(is_training, batch_statistics, population_statistics)           

tf.cond輸入一個布爾tensor,一個true傳回函數,一個false傳回函數,舉例

a = tf.convert_to_tensor([1, 2, 3])
b = tf.constant(3)
c = tf.constant(False)
result = tf.cond(c, lambda: tf.add(a, a), lambda: tf.square(b))

with tf.Session() as sess:
    res = sess.run(result)
    print(res)  # 9,如果c是True傳回[2 4 6]           

除此之外還需要以下八股文代碼

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # loss的例子
    self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)           

這段代表表示訓練過程中記錄的滑動平均均值和标準差這個操作存儲在tf.GraphKeys.UPDATE_OPS中,每次進行一次loss計算需要也計算一遍,希望在計算loss之前也把滑動平均也計算一邊是以采用tf.control_dependencies,表示with下文的内容必須在control_dependencies後面的條件完成之後才能運作。如果不加這個網絡僅僅計算loss操作,中間滑動平均根本不計算,這不影響訓練因為訓練用不到,但是幾乎徹底摧毀預測,導緻預測的參考均值标準差都是初始值。另一種就是把update_ops拿出來和train_step合并成一個最終的訓練操作

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    train_ops = [train_step] + update_ops
    train_op_final = tf.group(*train_ops)           

這種不區分執行的前後順序了,好像也可以。

代碼實戰

以下代碼測試一下batch_norm在一個網絡中的應用,分别對比有bn和沒有bn的網絡訓練情況

import os
import time
import shutil
import pickle
import tensorflow as tf

from tensorflow.contrib.layers.python.layers import batch_norm
from tensorflow.python.saved_model import tag_constants

from preprocessing import get_batch


def batch_norm_layer(value, is_training, scope):
    def batch_statistics():
        return batch_norm(value, decay=0.9, updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=True, scope=scope)

    def population_statistics():
        return batch_norm(value, decay=0.9, updates_collections=tf.GraphKeys.UPDATE_OPS, is_training=False, reuse=True, scope=scope)

    return tf.cond(is_training, batch_statistics, population_statistics)


class Model(object):
    def __init__(self, num_class, feature_size, learning_rate=0.5, weight_decay=0.01, decay_learning_rate=0.99):
        self.input_x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_class], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        self.batch_normalization = tf.placeholder(tf.bool, name="batch_normalization")
        self.global_step = tf.Variable(0, name="global_step", trainable=False)

        with tf.name_scope('layer_1'):
            dense_out_1 = tf.layers.dense(self.input_x, 64)
            # add
            dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn1")
            dense_out_act_1 = tf.nn.relu(dense_out_1)

        with tf.name_scope('layer_2'):
            dense_out_2 = tf.layers.dense(dense_out_act_1, 32)
            # add
            dense_out_2 = batch_norm_layer(dense_out_2, is_training=self.batch_normalization, scope="bn2")
            dense_out_act_2 = tf.nn.relu(dense_out_2)

        with tf.name_scope('layer_out'):
            self.output = tf.layers.dense(dense_out_act_2, 2)
            self.probs = tf.nn.softmax(self.output, dim=1, name="probs")

        with tf.name_scope('loss'):
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.output, labels=self.input_y))
            vars = tf.trainable_variables()
            loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if
                                v.name not in ['bias', 'gamma', 'b', 'g', 'beta']]) * weight_decay
            self.loss += loss_l2

        with tf.name_scope("optimizer"):
            if decay_learning_rate:
                learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, decay_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)
            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # if update_ops:
            #     train_ops = [self.train_step] + update_ops
            #     self.train_op_final = tf.group(*train_ops)

        with tf.name_scope("metrics"):
            self.accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.arg_max(self.probs, 1), tf.arg_max(self.input_y, 1)), dtype=tf.float32))


if __name__ == '__main__':
    train_x, train_y = pickle.load(
        open("/home/myproject/BatchNormalizationTest/batch_normalization_test/data/train.pkl", "rb"))
    test_x, test_y = pickle.load(
        open("/home/myproject/BatchNormalizationTest/batch_normalization_test/data/test.pkl", "rb"))

    tf.reset_default_graph()
    model = Model(num_class=2, feature_size=15, weight_decay=0)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

    # all_variable = tf.global_variables()
    # for variable in all_variable:
    #     if "moving" in variable.name:
    #         print(variable.name, variable.eval())
    BASIC_PATH = "/home/myproject/BatchNormalizationTest/batch_normalization_test"
    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer())
        sess.run(init_op)

        train_batch = get_batch(10, 64, train_x, train_y)
        train_loss_list = []
        train_step_cnt = []
        acc_list = []
        val_feed_dict = {model.input_x: test_x, model.input_y: test_y, model.dropout_keep_prob: 1,
                         model.batch_normalization: False}
        for batch in train_batch:
            epoch, batch_x, batch_y = batch
            feed_dict = {model.input_x: batch_x, model.input_y: batch_y, model.dropout_keep_prob: 1,
                         model.batch_normalization: True}
            _, step, loss_train = sess.run([model.train_step, model.global_step, model.loss], feed_dict=feed_dict)
            train_loss_list.append(loss_train)
            train_step_cnt.append(step)
            if step % 10 == 0:
                print("epoch:", epoch + 1, "step:", step, "loss:", loss_train)
                # ckpt
                saver.save(sess, os.path.join(BASIC_PATH, "./ckpt2/ckpt"))

            if step % 50 == 0:
                loss_val, acc_val, probs = sess.run([model.loss, model.accuracy, model.probs], feed_dict=val_feed_dict)
                print("{:-^30}".format("evaluation"))
                print("[evaluation]", "loss:", loss_val, "acc", acc_val)

        loss_val, acc_val, probs = sess.run([model.loss, model.accuracy, model.probs], feed_dict=val_feed_dict)
        print("{:-^30}".format("evaluation"))
        print("[evaluation]", "loss:", loss_val, "acc", acc_val)

    import matplotlib.pyplot as plt

    plt.plot(train_step_cnt, train_loss_list)
    plt.ylim([0.25, 2])
    plt.show()

    # save
    pb_num = str(int(time.time()))
    pb_path = os.path.join(BASIC_PATH, "./tfserving2", pb_num)
    shutil.rmtree(pb_path, ignore_errors=True)
    tf.reset_default_graph()
    with tf.Session() as sess:
        last_ckpt = tf.train.latest_checkpoint(os.path.join(BASIC_PATH, "./ckpt2"))
        print("讀取ckpt: {}".format(last_ckpt))
        saver = tf.train.import_meta_graph("{}.meta".format(last_ckpt))
        saver.restore(sess, last_ckpt)
        graph = tf.get_default_graph()
        # get tensor
        input_x = graph.get_tensor_by_name("input_x:0")
        dropout_keep_prob = graph.get_tensor_by_name("dropout_keep_prob:0")
        batch_norm_is_train = graph.get_tensor_by_name("batch_normalization:0")
        pred = graph.get_tensor_by_name("layer_out/probs:0")
        builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
        inputs = {'input_x': tf.saved_model.utils.build_tensor_info(input_x),
                  'dropout_keep_prob': tf.saved_model.utils.build_tensor_info(dropout_keep_prob),
                  'batch_norm': tf.saved_model.utils.build_tensor_info(batch_norm_is_train)
                  }
        outputs = {'output': tf.saved_model.utils.build_tensor_info(pred)}
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], {'my_signature': signature})
        builder.save()
    print("pb檔案儲存完成:", pb_num)           

代碼是兩個全連接配接層,每層線上性之後加入bn,bn之後是relu,其中bn隻使用均值偏移,不是用标準差,采用0.5的學習率運作結果如下

----------evaluation----------
[evaluation] loss: 0.48918334 acc 0.75525653
epoch: 10 step: 2210 loss: 0.38185906
epoch: 10 step: 2220 loss: 0.47259575
epoch: 10 step: 2230 loss: 0.62081766
epoch: 10 step: 2240 loss: 0.5432115
----------evaluation----------
[evaluation] loss: 0.49052832 acc 0.76843286           

同樣采用0.5的學習率,去除兩個bn層之後訓練結果如下

----------evaluation----------
[evaluation] loss: 0.6970511 acc 0.50602746
epoch: 10 step: 2210 loss: 0.71779585
epoch: 10 step: 2220 loss: 0.7110723
epoch: 10 step: 2230 loss: 0.72440666
epoch: 10 step: 2240 loss: 0.7036286
----------evaluation----------
[evaluation] loss: 0.6983369 acc 0.50602746           

對比下有bn和沒有bn在采用一個比較大的學習率的時候訓練階段網絡的收斂情況,先看全部訓練step,明顯發現最左側快速收斂階段使用bn比不是用bn更薄,後期不是用bn loss直接起飛,估計梯度爆炸了

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

收斂對比

再看前幾輪快速收斂階段,不是用bn前200輪還沒有收斂到0.6以下,使用bn已經收斂到最好0.4維持在0.5左右

深度學習基礎:圖文并茂細節到位batch normalization原理和實踐

收斂對比2

如果學習率回到正常比如0.01,兩個網絡的效果沒有明顯差別

預測階段和驗證一樣,設定一個布爾占位符給道False即可

import os
import pickle

from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

BASIC_PATH = "/home/myproject/BatchNormalizationTest/batch_normalization_test"


def predict_pb(input_x_value, pb_file_no=None):
    """從pb導入模型"""
    max_time = pb_file_no
    if max_time is None:
        max_time = max(os.listdir(os.path.join(BASIC_PATH, "./tfserving2")))
    # max_time = "1672132226"
    print("讀取pb版本:", max_time)
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, [tag_constants.SERVING], os.path.join(BASIC_PATH, "./tfserving2", max_time))
        graph = tf.get_default_graph()
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
        batch_norm_is_train = graph.get_operation_by_name("batch_normalization").outputs[0]
        probs = graph.get_tensor_by_name("layer_out/probs:0")

        pred = sess.run(probs, feed_dict={input_x: input_x_value, dropout_keep_prob: 1.0, batch_norm_is_train: False})

    return pred


if __name__ == '__main__':
    test_x, test_y = pickle.load(
        open("/home/myproject/BatchNormalizationTest/batch_normalization_test/data/test.pkl", "rb"))
    pred = predict_pb(test_x).tolist()
    pred = [0 if x[1] < 0.5 else 1 for x in pred]
    test_y = [0 if x == [1, 0] else 1 for x in test_y]
    print(accuracy_score(test_y, pred))           

參考文章:

batchnorm BN無法更新儲存參數 moving_mean/variance_batch_norm儲存失效_文草彙的三色堇的部落格-CSDN部落格

slim.batch_norm無法更新以及儲存參數_DRACO于的部落格-CSDN部落格

batch normalization的原理和作用_了解Batch Normalization系列1——原理(清晰解釋)_weixin_39627201的部落格-CSDN部落格

繼續閱讀