天天看點

tf.add_to_collection用法

tf.add_to_collection:使用預設圖的Graph.add_to_collection()的包裝。

tf.compat.v1.add_to_collection(
    name, value
)
           

請參閱

tf.Graph.add_to_collection

 以擷取更多詳細資訊。

Args

name

集合的鍵。例如,

GraphKeys

類包含許多集合的标準名稱。

value

要添加到集合中的值。

Eager Compatibility

當在EagerVariableStore中建立變量時(例如,作為圖層或模闆的一部分),在eager隻支援集合。

  • tf.add_to_collection('list_name', element):将元素element添加到清單list_name中
  • tf.get_collection('list_name'):傳回名稱為list_name的清單,是個清單
  • tf.add_n(list):将清單元素相加并傳回一個張量
import tensorflow as tf

tf.add_to_collection('losses', tf.constant(2.2))
tf.add_to_collection('losses', tf.constant(3.))

with tf.Session() as sess:
    print(sess.run(tf.get_collection('losses')))
    print(sess.run(tf.add_n(tf.get_collection('losses'))))

# 輸出結果:
# [2.2, 3.0]
# 5.2
           
import tensorflow as tf

v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
tf.add_to_collection('loss', v1)

v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    get_coll = tf.get_collection('loss')
    print(get_coll)
    print(sess.run(get_coll))

    get_add_n = tf.add_n(tf.get_collection('loss'))
    print(get_add_n)
    print(sess.run(get_add_n))
    
    
# 輸出結果:    
# [<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
# [array([0.], dtype=float32), array([2.], dtype=float32)]
# Tensor("AddN:0", shape=(1,), dtype=float32)
# [2.]