
TensorFlow: tf.variable_scope and tf.get_variable

Sources:

https://zhuanlan.zhihu.com/p/37711713

https://zhuanlan.zhihu.com/p/37922147

import tensorflow as tf

with tf.variable_scope('variable_scope_test'):
	v1 = tf.get_variable('v', shape=[1],  initializer=tf.constant_initializer(1.0))
	v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
	v3 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	print('the name of v1:', v1.name)
	print('the name of v2:', v2.name)
	print('the name of v3:', v3.name)
# Output:
#the name of v1: variable_scope_test/v:0
#the name of v2: variable_scope_test/v_1:0
#the name of v3: variable_scope_test/v_2:0
           

tf.variable_scope prefixes the original variable names with the name of the variable scope.
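As a minimal sketch (the scope names here are only illustrative), nested scopes stack their prefixes, so the full variable name records where it was created:

import tensorflow as tf

with tf.variable_scope('outer'):
    with tf.variable_scope('inner'):
        w = tf.get_variable('w', shape=[1])

print('the name of w:', w.name)
# the name of w: outer/inner/w:0

In the next example the same scope is opened a second time and tf.get_variable is asked for a name that already exists there, which raises an error: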

import tensorflow as tf

with tf.variable_scope('variable_scope_test'):
	v1 = tf.get_variable('v', shape=[1],  initializer=tf.constant_initializer(1.0))
	v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
	v3 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')

with tf.variable_scope('variable_scope_test'):
	v4 = tf.get_variable('v', shape=[1])	

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	print('the name of v1:', v1.name)
	print('the name of v2:', v2.name)
	print('the name of v3:', v3.name)
	print('the name of v4:', v4.name)

# Raises: ValueError: Variable variable_scope_test/v already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
           

Cause of the error: the second with tf.variable_scope('variable_scope_test'): block defines a variable named v in the variable_scope_test namespace again; that is, the line v4 = tf.get_variable('v', shape=[1]) redefines a variable that already exists.

The fix is as follows:

import tensorflow as tf

with tf.variable_scope('variable_scope_test'):
	v1 = tf.get_variable('v', shape=[1],  initializer=tf.constant_initializer(1.0))
	v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
	v3 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')

with tf.variable_scope('variable_scope_test', reuse=True): # set reuse=True
	v4 = tf.get_variable('v', shape=[1])	

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	print('the name of v1:', v1.name)
	print('the name of v2:', v2.name)
	print('the name of v3:', v3.name)
	print('the name of v4:', v4.name)
# Output:
#the name of v1: variable_scope_test/v:0
#the name of v2: variable_scope_test/v_1:0
#the name of v3: variable_scope_test/v_2:0
#the name of v4: variable_scope_test/v:0
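Note that with reuse=True, tf.get_variable hands back the existing variable object itself rather than a copy, so v4 and v1 refer to the same variable; a quick check, continuing the code above:

print(v4 is v1)
# True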
           

Similarly, the following case also raises an error:

import tensorflow as tf

with tf.variable_scope('variable_scope_test'):
	v1 = tf.get_variable('v', shape=[1],  initializer=tf.constant_initializer(1.0))
	v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
	v3 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')

with tf.variable_scope('variable_scope_test', reuse=True):
	v4 = tf.get_variable('v1', shape=[1])	

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	print('the name of v1:', v1.name)
	print('the name of v2:', v2.name)
	print('the name of v3:', v3.name)
	print('the name of v4:', v4.name)

# ValueError: Variable variable_scope_test/v1 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
           

The reason for this error is that once reuse=True is set, tf.get_variable can only fetch variables that already exist in that variable scope; it cannot create new ones. But what if you want to both create new variables and reuse (fetch) existing ones? You can use the following approach:

import tensorflow as tf

with tf.variable_scope('variable_scope_test'):
	v1 = tf.get_variable('v', shape=[1], initializer=tf.constant_initializer(1.0))
	v2 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')
	v3 = tf.Variable(tf.constant(1.0, shape=[1]), name='v')

with tf.variable_scope('variable_scope_test') as scope:
	v4 = tf.get_variable('v1', shape=[1], initializer=tf.constant_initializer(1.0))
	scope.reuse_variables()
	v5 = tf.get_variable('v', shape=[1])	

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	print('the name of v1:', v1.name)
	print('the name of v2:', v2.name)
	print('the name of v3:', v3.name)
	print('the name of v4:', v4.name)
	print('the name of v5:', v5.name)

# Output:
#the name of v1: variable_scope_test/v:0
#the name of v2: variable_scope_test/v_1:0
#the name of v3: variable_scope_test/v_2:0
#the name of v4: variable_scope_test/v1:0
#the name of v5: variable_scope_test/v:0
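Both error messages above also mention reuse=tf.AUTO_REUSE. As a minimal sketch (the scope name auto_scope is only illustrative), AUTO_REUSE makes tf.get_variable create the variable when it does not yet exist and return the existing one when it does, covering both cases without calling scope.reuse_variables():

import tensorflow as tf

with tf.variable_scope('auto_scope', reuse=tf.AUTO_REUSE):
    a = tf.get_variable('v', shape=[1], initializer=tf.constant_initializer(1.0))  # created

with tf.variable_scope('auto_scope', reuse=tf.AUTO_REUSE):
    b = tf.get_variable('v', shape=[1])  # fetched, not recreated

print('the name of a:', a.name)  # auto_scope/v:0
print('the name of b:', b.name)  # auto_scope/v:0
print(a is b)                    # True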
           

Combined with a variable scope, tf.get_variable expresses whether we want to create a new variable or share an existing one: if a variable with that name has already been instantiated through get_variable (and reuse is enabled), get_variable simply returns the earlier variable; otherwise it creates a new one.

# A single convolution layer
def conv_relu(input, kernel_shape, bias_shape):

    # Create variable named "weights".
    weights = tf.get_variable("weights", kernel_shape,
        initializer=tf.random_normal_initializer())

    # Create variable named "biases".
    biases = tf.get_variable("biases", bias_shape,
        initializer=tf.constant_initializer(0.0))

    conv = tf.nn.conv2d(input, weights,
        strides=[1, 1, 1, 1], padding='SAME')

    return tf.nn.relu(conv + biases)

# Multiple convolution layers use variable scopes to keep each layer's variables apart
# Variable names take the form scope_name/variable_name
def my_image_filter(input_images):

    with tf.variable_scope("conv1"):
        # Variables created here will be named "conv1/weights", "conv1/biases".
        relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])

    with tf.variable_scope("conv2"):
        # Variables created here will be named "conv2/weights", "conv2/biases".

        return conv_relu(relu1, [5, 5, 32, 32], [32])
           

However, calling the my_image_filter function more than once also raises an error:

result1 = my_image_filter(image1)
result2 = my_image_filter(image2)
# Raises ValueError(... conv1/weights already exists ...)
           

This is because creating two variables with the same name via get_variable() raises an error: get_variable() only checks the variable name in order to prevent duplicates. To share variables, you have to specify the scope in which sharing is allowed.

Method 1: use scope.reuse_variables()

with tf.variable_scope("model") as scope:
  output1 = my_image_filter(input1)
  scope.reuse_variables() # allow variables in this scope to be reused, so output2 below can share output1's variables
  output2 = my_image_filter(input2)
           

Method 2: open a scope with the same name, passing reuse=True

with tf.variable_scope("model"):
  output1 = my_image_filter(input1)
with tf.variable_scope("model", reuse=True):
  output2 = my_image_filter(input2)
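To check that the two calls really share one set of parameters, you can list the trainable variables under the model scope; a rough sketch, assuming my_image_filter and the inputs above:

# With sharing enabled there is only one weights/biases pair per layer,
# i.e. four trainable variables in total under model/.
for v in tf.trainable_variables():
    if v.name.startswith('model/'):
        print(v.name)
# model/conv1/weights:0
# model/conv1/biases:0
# model/conv2/weights:0
# model/conv2/biases:0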
           

Summary

  • tf.Variable() creates a new variable every time it is called; even if the name is already taken, it still creates a new variable (with an auto-renamed name), so variables created this way cannot be shared by name;
  • tf.get_variable() by default only checks the variable name and raises an error if that name already exists;
  • If the scope has not enabled variable sharing (the default mode), calling tf.get_variable() raises an error when a variable with the same name already exists; otherwise it creates a new variable.
  • If the scope has variable sharing enabled, calling tf.get_variable() looks up a variable with that name and returns it directly if it exists; if it does not exist, an error is raised (unless reuse=tf.AUTO_REUSE is used, in which case a new variable is created);
  • To reuse variables, sharing must be enabled in the scope; there are two ways to do this, and the first is recommended. The short sketch below pulls these points together.
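A short sketch illustrating the summary (the scope name demo is only illustrative):

import tensorflow as tf

with tf.variable_scope('demo'):
    a = tf.Variable(1.0, name='x')       # always creates a new variable
    b = tf.Variable(1.0, name='x')       # duplicate name is auto-renamed to demo/x_1
    c = tf.get_variable('y', shape=[1])  # creates demo/y

with tf.variable_scope('demo', reuse=True):
    d = tf.get_variable('y', shape=[1])  # sharing enabled: returns the existing demo/y

print(a.name, b.name)          # demo/x:0 demo/x_1:0
print(c.name, d.name, c is d)  # demo/y:0 demo/y:0 True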
