版權聲明:本文為部落客原創文章,如需轉載,請注明出處, 謝謝。
https://blog.csdn.net/u014061630/article/details/85104491
Batch Normalization 作為深度學習中一個常用層,掌握其的使用非常重要,本部落格将梳理下各種 Batch Normalization API 的一些坑。
如果你對 Batch Normalization 還不清楚,可以檢視之前的部落格 Inception v2/BN-Inception:Batch Normalization 論文筆記 來學習下 Batch Normalization。
在 TensorFlow 中,Batch Normlization 有以下幾個實作(API):
-
tf.layers.BatchNormalization
-
tf.layers.batch_normalization
-
tf.keras.layers.BatchNormalization
-
tf.nn.batch_normalization
上述四個 API 按層次可分為兩類:
- 高階 API:
、tf.layers.BatchNormalization
、tf.layers.batch_normalization
tf.keras.layers.BatchNormalization
- 低階 API:
tf.nn.batch_normalization
上述四個 API 按行為可以分為兩類:
- TensorFlow API:
、tf.layers.BatchNormalization
、tf.layers.batch_normalization
tf.nn.batch_normalization
- Keras API:
tf.keras.layers.BatchNormalization
1. 頭号大坑 ----- 沒有調用 update_ops
Batch Normalization 中需要計算移動平均值,是以 BN 中有一些
update_ops
,在訓練中需要通過
tf.control_dependencies()
來添加對
update_ops
的調用:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = tf.train.AdamOptimizer(lrn_rate).minimize(cost)
在使用
update_ops
前,需要将 BN 層的 update_ops 添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中。
tf.layers.BatchNormalization
和
tf.layers.batch_normalization
會自動将 update_ops 添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中(注:
training
參數為 True 時,才會添加,False 時不添加)。
你以為躲過頭号坑就完了,哈哈哈,還有二号坑!
2. 二号坑(已修複) ----- 在 TensorFlow 訓練 session 中使用 tf.keras.layers.BatchNormalization
tf.keras.layers.BatchNormalization
不會自動将 update_ops 添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中。是以在 TensorFlow 訓練 session 中使用
tf.keras.layers.BatchNormalization
時,需要手動将 keras.BatchNormalization 層的
updates
添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中。
   \;
檢測是否添加了 updates 的方法:import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
Keras 是以面向對象的方式編寫的,是以我們可以通過以下方式擷取 BN 層的 updates,并将其添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中:
x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 調用後updates屬性才會有内容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
那麼如果我已經以下列形式編寫完了代碼,該怎麼将 BN 的 updates 添加到
tf.GraphKeys.UPDATE_OPS
這個 collection 中呢?
程式都這樣編寫了,難道我們就必須更改所有的 BN 層的編寫方式嗎?
答:不要改代碼了,直接從 ops 中過濾出 updates_ops,然後添加到指定的 collectuon 中即可。代碼如下:
ops = tf.get_default_graph().get_operations()
update_ops = [x for x in ops if ("AssignMovingAvg" in x.name and x.type=="AssignSubVariableOp")]
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_ops)
通過一些技術,我們已經将 BN 層的 updates 添加到了指定的 collection 中,然後按照
1.的方式處理下即可。
經過上面的努力,以後終于可以在 TF-session 中順暢地使用tf.keras.layers
API 了,真高興!(不要忘記點贊哦) 3. 附錄 其他關于 TF - BN 使用的一些資料:
- tensorflow batch_normalization的正确使用姿勢