天天看點

tl儲存和讀取模型

輸入參數

1

其中save_list為所要儲存的參數,name為路徑和儲存的檔案名,傳入一個sess來執行此次操作。 

至于其他讀取方式的API于此類似,這裡不再贅述。

輸入參數 

<a href="http://tensorlayercn.readthedocs.io/zh/latest/modules/files.html#id11" target="_blank">TensorLayer中文版文檔-API-load_and_assign_npz</a>

其中sess和name意義與儲存(<code>save_npz</code>)相同,但是要注意的是<code>network</code>,到底該傳入什麼? 

查閱文檔可知

Parameters:  sess : TensorFlow Session  name : string Model path.  network : a <code>Layer</code> class The network to be assigned Returns: Returns False if faild to model is not exist.

顯而易見,應該傳入一個<code>Layer</code>類,然而如果隻是簡單初始化一個<code>Layer</code>類的變量傳進去,運作立刻就會報錯。

2

原因在于TensorLayer(也可以說是TensorFlow)所謂的儲存,隻是儲存模型的參數和變量的值,而不是模型本身。這一點和sklearn中的模型儲存是有差別的。 

是以隻有當讀取時的<code>Layer</code>(即模型)和儲存時的模型結構上一模一樣,才可以将儲存的模型參數一一對應,進而對模型指派。

舉例:  模型儲存時有3層,分别有800,500,400個節點,意味着有800+500+400個參數(不嚴謹,但可以這麼了解)。  然而在讀取的時候,也必須建構一個3層,分别有800,500,400個節點的模型才能将這些參數一一指派到對應的位置上去。  這兩個模型的唯一差別就在于前者模型的節點參數已經被指派,後者并沒有。

上述提到要建構一個和存儲前一模一樣模型骨架,首要的問題在于,描述一個模型骨架的完備的條件是什麼? 

1個網絡,3個方法(訓練,損失,精确函數)

3

4

5

6

7

8

9

10

接下來就可以根據自己的需求來設計模型的骨架了。 

要注意的是設計模型骨架時不需要使用<code>Session</code>來執行任何語句(因為暫時不需要訓練),同時模型骨架的接口,也就是資料輸入的模式<code>x_placeholder</code>需要傳入到模型骨架裡。 

代碼如下:

11

12

13

14

15

16

17

18

19

20

21

22

23

24

到此一個完整的新的模型骨架就定義在了<code>MyNetwork</code>這個對象中了。

定義好自己的模型骨架後就可以開始訓練了。這裡使用TL提供的<code>utils.fit</code>方法進行訓練。

模型存儲非常簡單,我們需要存儲模型<code>MyNetwork</code>中的<code>network</code>,是以隻需要導出<code>network</code>的所有參數并作為輸入即可。

存儲後在<code>model</code>檔案夾中會有一個<code>model5.npz</code>的檔案。這個就是模型的參數檔案了,裡面包含了模型所有參數資料。

上文提到要想讀取模型,首先要建構一個和被存儲模型一模一樣的骨架才能将所有參數一一指派給讀取模型的骨架。 

是以,要做的第一步就是複現存儲時的模型骨架。這部分代碼基本與構模組化型的基本一緻。唯一的差別在于,此時并不需要訓練。

到此模型的存儲和讀取介紹完畢。這些隻是TensorLayer下的存儲和讀取方式,但是對于Tensorflow其思路是一樣的。 

要記住的是,直接存儲模型太抽象,難以移植,而模型最重要的就是模型中每一個節點的權值與參數,是以隻要有了參數,就可以在任何平台上重新搭模組化型骨架,通過給每一個節點指派,複現存儲前的模型。

以下是本次教程的完整代碼: