天天看點

tensorflow的variable、variable_scope和get_variable的用法和差別

在tensorflow中,可以使用tf.Variable來建立一個變量,也可以使用tf.get_variable來建立一個變量,但是在一個模型需要使用其他模型的變量時,tf.get_variable就派上大用場了。

先分别介紹兩個函數的用法:

import tensorflow as tf
var1 = tf.Variable(1.0,name='firstvar')
print('var1:',var1.name)
var1 = tf.Variable(2.0,name='firstvar')
print('var1:',var1.name)
var2 = tf.Variable(3.0)
print('var2:',var2.name)
var2 = tf.Variable(4.0)
print('var2:',var2.name)
get_var1 = tf.get_variable(name='firstvar',shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3))
print('get_var1:',get_var1.name)
get_var1 = tf.get_variable(name='firstvar1',shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
print('get_var1:',get_var1.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('var1=',var1.eval())
    print('var2=',var2.eval())
    print('get_var1=',get_var1.eval())
      

 結果如下:

tensorflow的variable、variable_scope和get_variable的用法和差別

我們來分析一下代碼,tf.Varibale是以定義的變量名稱為唯一辨別的,如var1,var2,是以可以重複地建立name='firstvar'的變量,但是tensorflow會給它們按順序取字尾,如firstvar_1:0,firstval_2:0,...,如果沒有制定名字,系統會自動加上一個名字Variable:0。而且由于tf.Varibale是以定義的變量名稱為唯一辨別的,是以當第二次命名同一個變量名時,第一個變量就會被覆寫,是以var1由1.0變成2.0。

對于tf.get_variable,它是以指定的name屬性為唯一辨別,而不是定義的變量名稱,是以不能同時定義兩個變量name是相同的,例如下面這種就會報錯:

1 get_var1 = tf.get_variable(name='a',shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3))
2 print('get_var1:',get_var1.name)
3 get_var2 = tf.get_variable(name='a',shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
4 print('get_var1:',get_var1.name)
      

  這樣就會報錯了。如果我們想聲明兩次相同name的變量,這時variable_scope就派上用場了,可以使用variable_scope将它們分開:

import tensorflow as tf
with tf.variable_scope('test1'):
    get_var1 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
with tf.variable_scope('test2'):
    get_var2 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
print('get_var1:',get_var1.name)
print('get_var2:',get_var2.name)
      

  這樣就不會報錯了,variable_scope相當于聲明了作用域,這樣在不同的作用域存在相同的變量就不會沖突了,結果如下:

tensorflow的variable、variable_scope和get_variable的用法和差別

 當然,scope還支援嵌套:

import tensorflow as tf
with tf.variable_scope('test1',):
    get_var1 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
    with tf.variable_scope('test2',):
        get_var2 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
print('get_var1:',get_var1.name)
print('get_var2:',get_var2.name)
      

  輸出結果為:

tensorflow的variable、variable_scope和get_variable的用法和差別

 怎麼樣?可以對照上面的結果體會一下不同。那麼如何通過get_variable來實作變量共享呢?這就要用到variable_scope裡的一個屬性:reuse,顧名思義嘛,當把reuse設定成True時就可以了,它表示使用已經定義過的變量,這是get_variable就不會再建立新的變量,而是去找與name相同的變量:

import tensorflow as tf
with tf.variable_scope('test1',):
    get_var1 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
    with tf.variable_scope('test2',):
        get_var2 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
print('get_var1:',get_var1.name)
print('get_var2:',get_var2.name)
with tf.variable_scope('test1',reuse=True):
    get_var3 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
    with tf.variable_scope('test2',):
        get_var4 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
print('get_var3:',get_var3.name)
print('get_var4:',get_var4.name)
      

  輸出結果如下:

tensorflow的variable、variable_scope和get_variable的用法和差別

 當然前面說過,reuse=True是使用前面已經建立過的變量,如果僅僅隻有從第八行到最後的代碼,也會報錯的,如果還是想這麼做,就需要把reuse屬性設定成tf.AUTO_REUSE

import tensorflow as tf
with tf.variable_scope('test1',reuse=tf.AUTO_REUSE):
    get_var3 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
    with tf.variable_scope('test2',):
        get_var4 = tf.get_variable(name='firstvar',shape=[2],dtype=tf.float32)
print('get_var3:',get_var3.name)
print('get_var4:',get_var4.name)
      

  此時就不會報錯,tf.AUTO_REUSE可以實作第一次調用variable_scope時,傳入的reuse值為False,再次調用時,傳入reuse的值就會自動變為True。

多思考也是一種努力,做出正确的分析和選擇,因為我們的時間和精力都有限,是以把時間花在更有價值的地方。

繼續閱讀