天天看點

Tensorflow函數:tf.get_variable()

一. 函數的作用

該函數的主要作用是擷取已存在的變量(要求不僅名字,而且初始化方法等各個參數都一樣),若發現不存在則建立一個新變量;其可以采用各種初始化方法,不用明确指定值。

(與之相比的tf.Variable()則是每次均建立一個值)

二. 函數的參數說明

1. 函數的整體結構如下:

tf.get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, validate_shape)

2. 函數中的各個參數解釋如下:

  • name: 新變量或現有變量的名稱。
  • shape: 新變量或現有變量的形狀。
  • dtype: 新變量或現有變量的類型。
  • initializer: 可以了解為一個初始化器,如果建立了,則用它來初始化變量,預設為None,常見的初始化器如下:

         tf.random_normal_initializer(mean, stddev, seed, dtype)

         tf.truncated_normal_initializer(mean, stddev, seed, dtype)

         tf.random_uniform_initializer(minval, maxval, seed, dtype)

         tf.uniform_unit_scaling_initializer(factor, seed, dtype)

         tf.constant_initializer(value, dtype, name)

         tf.zeros_initializer(dtype)

         tf.ones_initializer(dtype)

  • regularizer: 指一個正則化對象,其可将于新建立的變量的結果添加到集合tf.GraphKeys.REGULARIZATION_LOSS中,并可用于正則化。
  • trainable: 若為‘True’,則該變量為可訓練變量,自動被加入GraphKeys.TRAINABLE_VARIABLES。
  • collections: 為一個集合清單的關鍵字,新變量将被添加到這個集合中,預設[GraphKeys.GLOBAL_VARIABLES]。
  • caching_device: 可選裝置字元串,描述應該緩存變量以供讀取的位置。
  • validate_shape: 預設為True,表示該變量的形狀不接受更改。

注:a. 如果initializer初始化方法是None(預設值),則會使用variable_scope()中定義的initializer;如果變量管理器中對應的參數也為None,則預設使用glorot_uniform_initializer;其也可以使用其他的tensor來初始化,進一步了解可參考部落格。

       b. 正則化方法對象regularizer預設是None,如果不指定,則會采用變量管理器variable_scope()中的正則化方式;如果變量管理器中對應的參數也為None,則不使用正則化,進一步了解可參考部落格。

       c. 可以通過tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)函數檢視參與正則化的變量

3. 函數的使用

1. 采用tf.get_variable()進行變量建立

#*******************************導入相關子產品***********************************#
import tensorflow as tf
import numpy as np
 
#*******************************聲明兩個變量***********************************#
x1 = tf.get_variable('x1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=0.1))

x2 = tf.get_variable('x2', shape=[1,3], initializer=tf.constant_initializer([4,5,6]))
 
#********************************建立會話*************************************#
with tf.Session() as sess:
    #-------------------變量進行初始化
    init_op = tf.global_variables_initializer()
    sess.run( init_op )
    #-------------------輸出變量及名稱
    print( sess.run(x1) )
    print( x1.op.name )
    print( sess.run(x2) )
    print( x2.op.name )
 
 
#--------------------模型的輸出
[[ 0.08277216 -0.14316109  0.03541737]
 [ 0.02363679 -0.19219622 -0.17002776]]

x1

[[4. 5. 6.]]

x2
           

2. 使用該函數的優點:

          a. 初始化更友善,比如用xavier_initializer()初始化器。

          b. 友善共享變量。因為tf.get_variable()會檢查目前命名空間下是否存在同樣name的變量,可以友善共享變量。而tf.Variable()每次都會建立一個變量。需要注意的是tf.get_variable()往往需要要配合reuse參數和tf.variable_scope()使用。

繼續閱讀