天天看點

TensorFlow 中 Batch Normalization API 的一些坑1. 頭号大坑 ----- 沒有調用 update_ops2. 二号坑(已修複) ----- 在 TensorFlow 訓練 session 中使用 tf.keras.layers.BatchNormalization

版權聲明:本文為部落客原創文章,如需轉載,請注明出處, 謝謝。

https://blog.csdn.net/u014061630/article/details/85104491

Batch Normalization 作為深度學習中一個常用層,掌握其的使用非常重要,本部落格将梳理下各種 Batch Normalization API 的一些坑。

如果你對 Batch Normalization 還不清楚,可以檢視之前的部落格 Inception v2/BN-Inception:Batch Normalization 論文筆記 來學習下 Batch Normalization。

在 TensorFlow 中,Batch Normlization 有以下幾個實作(API):

  • tf.layers.BatchNormalization

  • tf.layers.batch_normalization

  • tf.keras.layers.BatchNormalization

  • tf.nn.batch_normalization

上述四個 API 按層次可分為兩類:

  • 高階 API:

    tf.layers.BatchNormalization

    tf.layers.batch_normalization

    tf.keras.layers.BatchNormalization

  • 低階 API:

    tf.nn.batch_normalization

上述四個 API 按行為可以分為兩類:

  • TensorFlow API:

    tf.layers.BatchNormalization

    tf.layers.batch_normalization

    tf.nn.batch_normalization

  • Keras API:

    tf.keras.layers.BatchNormalization

1. 頭号大坑 ----- 沒有調用 update_ops

Batch Normalization 中需要計算移動平均值,是以 BN 中有一些

update_ops

,在訓練中需要通過

tf.control_dependencies()

來添加對

update_ops

的調用:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(lrn_rate).minimize(cost)
           

在使用

update_ops

前,需要将 BN 層的 update_ops 添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中。

tf.layers.BatchNormalization

tf.layers.batch_normalization

會自動将 update_ops 添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中(注:

training

參數為 True 時,才會添加,False 時不添加)。

你以為躲過頭号坑就完了,哈哈哈,還有二号坑!

2. 二号坑(已修複) ----- 在 TensorFlow 訓練 session 中使用 tf.keras.layers.BatchNormalization

tf.keras.layers.BatchNormalization

不會自動将 update_ops 添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中。是以在 TensorFlow 訓練 session 中使用

tf.keras.layers.BatchNormalization

時,需要手動将 keras.BatchNormalization 層的

updates

添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中。

   \;

檢測是否添加了 updates 的方法:
import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
           

Keras 是以面向對象的方式編寫的,是以我們可以通過以下方式擷取 BN 層的 updates,并将其添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中:

x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 調用後updates屬性才會有内容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
           

那麼如果我已經以下列形式編寫完了代碼,該怎麼将 BN 的 updates 添加到

tf.GraphKeys.UPDATE_OPS

這個 collection 中呢?

程式都這樣編寫了,難道我們就必須更改所有的 BN 層的編寫方式嗎?

答:

不要改代碼了,直接從 ops 中過濾出 updates_ops,然後添加到指定的 collectuon 中即可。代碼如下:

ops = tf.get_default_graph().get_operations()
update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_ops)
           

通過一些技術,我們已經将 BN 層的 updates 添加到了指定的 collection 中,然後按照

1.

的方式處理下即可。

經過上面的努力,以後終于可以在 TF-session 中順暢地使用

tf.keras.layers

API 了,真高興!(不要忘記點贊哦)
3. 附錄

其他關于 TF - BN 使用的一些資料:

  • tensorflow batch_normalization的正确使用姿勢

繼續閱讀