tf.add_to_collection:使用預設圖的Graph.add_to_collection()的包裝。
tf.compat.v1.add_to_collection(
name, value
)
請參閱
tf.Graph.add_to_collection
以擷取更多詳細資訊。
Args | |
---|---|
| 集合的鍵。例如, 類包含許多集合的标準名稱。 |
| 要添加到集合中的值。 |
Eager Compatibility
當在EagerVariableStore中建立變量時(例如,作為圖層或模闆的一部分),在eager隻支援集合。
- tf.add_to_collection('list_name', element):将元素element添加到清單list_name中
- tf.get_collection('list_name'):傳回名稱為list_name的清單,是個清單
- tf.add_n(list):将清單元素相加并傳回一個張量
import tensorflow as tf
tf.add_to_collection('losses', tf.constant(2.2))
tf.add_to_collection('losses', tf.constant(3.))
with tf.Session() as sess:
print(sess.run(tf.get_collection('losses')))
print(sess.run(tf.add_n(tf.get_collection('losses'))))
# 輸出結果:
# [2.2, 3.0]
# 5.2
import tensorflow as tf
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
get_coll = tf.get_collection('loss')
print(get_coll)
print(sess.run(get_coll))
get_add_n = tf.add_n(tf.get_collection('loss'))
print(get_add_n)
print(sess.run(get_add_n))
# 輸出結果:
# [<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
# [array([0.], dtype=float32), array([2.], dtype=float32)]
# Tensor("AddN:0", shape=(1,), dtype=float32)
# [2.]