天天看點

【轉載】TensorFlow學習筆記:共享變量

原文連結:http://jermmy.xyz/2017/08/25/2017-8-25-learn-tensorflow-shared-variables/

本文是根據 TensorFlow 官方教程翻譯總結的學習筆記,主要介紹了在 TensorFlow 中如何共享參數變量。

教程中首先引入共享變量的應用場景,緊接着用一個例子介紹如何實作共享變量(主要涉及到 <code>tf.variable_scope()</code>和<code>tf.get_variable()</code>兩個接口),最後會介紹變量域 (Variable Scope) 的工作方式。

假設我們建立了一個簡單的 CNN 網絡:

這個網絡中用 <code>tf.Variable()</code> 初始化了四個參數。

不過,别看我們用一個函數封裝好了網絡,當我們要調用網絡進行訓練時,問題就會變得麻煩。比如說,我們有 <code>image1</code> 和 <code>image2</code> 兩張圖檔,如果将它們同時丢到網絡裡面,由于參數是在函數裡面定義的,這樣一來,每調用一次函數,就相當于又初始化一次變量:

當然了,我們很快也能找到解決辦法,那就是把參數的初始化放在函數外面,把它們當作全局變量,這樣一來,就相當于全局「共享」了嘛。比如說,我們可以用一個 <code>dict</code> 在函數外定義參數:

為此,TensorFlow 内置了變量域這個功能,讓我們可以通過域名來區分或共享變量。通過它,我們完全可以将參數放在函數内部執行個體化,再也不用手動儲存一份很長的參數清單了。不過,這種方法對于熟悉面向對象的你來說,會不會有點别扭呢?因為它完全破壞了原有的封裝。也許你會說,不礙事的,隻要将參數和<code>filter</code>函數都放到一個類裡即可。不錯,面向對象的方法保持了原有的封裝,但這裡出現了另一個問題:當網絡變得很複雜很龐大時,你的參數清單/字典也會變得很冗長,而且如果你将網絡分割成幾個不同的函數來實作,那麼,在傳參時将變得很麻煩,而且一旦出現一點點錯誤,就可能導緻巨大的 bug。

這裡主要包括兩個函數接口:

<code>tf.get_variable(</code> :根據指定的變量名執行個體化或傳回一個 <code>tensor</code>對象;

<code>tf.variable_scope(</code>:管理 <code>tf.get_variable()</code> 變量的域名。

<code>tf.get_variable()</code> 的機制跟 <code>tf.Variable()</code> 有很大不同,如果指定的變量名已經存在(即先前已經用同一個變量名通過 <code>get_variable()</code> 函數執行個體化了變量),那麼 <code>get_variable()</code>隻會傳回之前的變量,否則才創造新的變量。

現在,我們用 <code>tf.get_variable()</code> 來解決上面提到的問題。我們将卷積網絡的兩個參數變量分别命名為 <code>weights</code> 和 <code>biases</code>。不過,由于總共有 4 個參數,如果還要再手動加個 <code>weights1</code> 、<code>weights2</code> ,那代碼又要開始惡心了。于是,TensorFlow 加入變量域的機制來幫助我們區分變量,比如:

不過,如果直接這樣調用 <code>my_image_filter</code>,是會抛異常的:我們先定義一個 <code>conv_relu()</code> 函數,因為 conv 和 relu 都是很常用的操作,也許很多層都會用到,是以單獨将這兩個操作提取出來。然後在 <code>my_image_filter()</code> 函數中真正定義我們的網絡模型。注意到,我們用 <code>tf.variable_scope()</code> 來分别處理兩個卷積層的參數。正如注釋中提到的那樣,這個函數會在内部的變量名前面再加上一個「scope」字首,比如:<code>conv1/weights</code>表示第一個卷積層的權值參數。這樣一來,我們就可以通過域名來區分各個層之間的參數了。

因為 <code>tf.get_variable()</code>雖然可以共享變量,但預設上它隻是檢查變量名,防止重複。要開啟變量共享,你還必須指定在哪個域名内可以共用變量:

到這一步,共享變量的工作就完成了。你甚至都不用在函數外定義變量,直接調用同一個函數并傳入不同的域名,就可以讓 TensorFlow 來幫你管理變量了。

==================== UPDATE 2018.3.8 ======================

官方的教程都是一些簡單的例子,但在實際開發中,情況可能會複雜得多。比如,有一個網絡,它的前半部分是要共享的,而後半部分則是不需要共享的,在這種情況下,如果還要自己去調用 <code>scope.reuse_variables()</code> 來決定共享的時機,無論如何都是辦不到的,比如下面這個例子:

這個例子中,我們要使用兩個變量: <code>w</code> 和 <code>u</code>,其中 <code>w</code> 是不共享的,而 <code>u</code> 是共享的。在這種情況下,不管你加不加 <code>scope.reuse_variables()</code>,代碼都會出錯。是以,Tensorflow 提供另一種開啟共享的方法:

這裡隻是加了一個參數 <code>reuse=tf.AUTO_REUSE</code>,但正如名字所示,這是一種自動共享的機制,當系統檢測到我們用了一個之前已經定義的變量時,就開啟共享,否則就重新建立變量。這幾乎是「萬金油」式的寫法????。

接下來我們再仔細梳理一下這背後發生的事情。

我們要先搞清楚,當我們調用 <code>tf.get_variable(name, shape, dtype, initializer)</code> 時,這背後到底做了什麼。

首先,TensorFlow 會判斷是否要共享變量,也就是判斷 <code>tf.get_variable_scope().reuse</code> 的值,如果結果為 <code>False</code>(即你沒有在變量域内調用<code>scope.reuse_variables()</code>),那麼 TensorFlow 認為你是要初始化一個新的變量,緊接着它會判斷這個命名的變量是否存在。如果存在,會抛出 <code>ValueError</code> 異常,否則,就根據 <code>initializer</code> 初始化變量:

而如果 <code>tf.get_variable_scope().reuse == True</code>,那麼 TensorFlow 會執行相反的動作,就是到程式裡面尋找變量名為 <code>scope name + name</code> 的變量,如果變量不存在,會抛出 <code>ValueError</code> 異常,否則,就傳回找到的變量:

了解變量域背後的工作方式後,我們就可以進一步熟悉其他一些技巧了。

變量域可以嵌套使用:

我們也可以通過 <code>tf.get_variable_scope()</code> 來獲得目前的變量域對象,并通過 <code>reuse_variables()</code> 方法來設定是否共享變量。不過,TensorFlow 并不支援将 <code>reuse</code> 值設為 <code>False</code>,如果你要停止共享變量,可以選擇離開目前所在的變量域,或者再進入一個新的變量域(比如,再進入一個 <code>with</code> 語句,然後指定新的域名)。

還需注意的一點是,一旦在一個變量域内将 <code>reuse</code> 設為 <code>True</code>,那麼這個變量域的子變量域也會繼承這個 <code>reuse</code> 值,自動開啟共享變量:

如果一直用字元串來區分變量域,寫起來容易出錯。為此,TensorFlow 提供了一個變量域對象來幫助我們管理代碼:

記住,用這個變量域對象還可以讓我們跳出目前所在的變量域區域:

每次初始化變量時都要傳入一個 <code>initializer</code>,這實在是麻煩,而如果使用變量域的話,就可以批量初始化參數了:

TensorFlow官方教程

繼續閱讀