天天看點

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

本文來自AI新媒體量子位(QbitAI)

本文作者Erik Hallström是一名深度學習研究工程師,他的這份教程以Echo-RNN為例,介紹了如何在TensorFlow環境中建構一個簡單的循環神經網絡。

RNN是循環神經網絡(Recurrent Neural Network)的英文縮寫,它能結合資料點之間的特定順序和幅值大小等多個特征,來處理序列資料。更重要的是,這種網絡的輸入序列可以是任意長度的。

舉一個簡單的例子:數字時間序列,具體任務是根據先前值來預測後續值。在每個時間步中,循環神經網絡的輸入是目前值,以及一個表征該網絡在之前的時間步中已經獲得資訊的狀态向量。該狀态向量是RNN網絡的編碼記憶單元,在訓練網絡之前初始化為零向量。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式
如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖1:RNN處理序列資料的步驟示意圖。

本文隻對RNN做簡要介紹,主要專注于實踐:如何建構RNN網絡。如果有網絡結構相關的疑惑,建議多看看說明性文章。

在了解RNN網絡的基本知識後,就很容易了解以下内容。

我們先建立一個簡單的回聲狀态網絡(Echo-RNN)。這種網絡能記憶輸入資料資訊,在若幹時間步後将其回傳。我們先設定若幹個網絡常數,讀完文章你就能明白它們的作用。

現在生成随機的訓練資料,輸入為一個随機的二進制向量,在<code>echo_step</code>個時間步後,可得到輸入的“回聲”,即輸出。

包含<code>batch_size</code>的兩行代碼,将資料重構為新矩陣。神經網絡的訓練,需要利用小批次資料(mini-batch),來近似得到關于神經元權重的損失函數梯度。在訓練過程中,随機批次操作能防止過拟合和降低硬體壓力。整個資料集通過資料重構轉化為一個矩陣,并将其分解為多個小批次資料。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖2:重構資料矩陣的示意圖,箭頭曲線訓示了在不同行上的相鄰時間步。淺灰色矩形代表“0”,深灰色矩形代表“1”。

首先在TensorFlow中建立一個計算圖,指定将要執行的運算。該計算圖的輸入和輸出通常是多元數組,也被稱為張量(tensor)。我們可以利用CPU、GPU和遠端伺服器的計算資源,在會話中疊代執行該計算圖。

本文所用的基本TensorFlow資料結構是變量和占位符。占位符是計算圖的“起始節點”。在運作每個計算圖時,批處理資料被傳遞到占位符中。另外,RNN狀态向量也是存儲在占位符中,在每一次運作後更新輸出。

網絡的權重和偏差作為TensorFlow的變量,在運作時保持不變,并在輸入批資料後進行逐漸更新。

下圖表示了輸入資料矩陣,以及虛線視窗指出了占位符的目前位置。在每次運作時,這個“批處理視窗”根據箭頭訓示方向,以定義好的長度從左邊滑到右邊。在示意圖中,<code>batch_size</code>(批資料數量)為3,<code>truncated_backprop_length</code>(截斷反傳長度)為3,<code>total_series_length</code>(全局長度)為36。這些參數是用來示意的,與實際代碼中定義的值不一樣。在示意圖中序列各點也以數字标出。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖3:訓練資料的示意圖,用虛線矩形訓示目前批資料,用數字标明了序列順序。

現在開始建構RNN計算圖的下個部分,首先我們要以相鄰的時間步分割批資料。

如下圖所示,可以按批次分解各列,轉成list格式檔案。RNN會同時從不同位置開始訓練時間序列:在示例中分别從4到6、從16到18和從28到30。用<code>plural</code>和<code>series</code>做變量名,是為了強調該變量為list檔案,用來在每一步中表示具有多個位置的時間序列。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖4:将資料拆分為多列的原理圖,用數字标出序列順序,箭頭表示相鄰的時間步。

在我們的時間序列資料中,在三個位置同時開啟訓練,是以在前向傳播時需要儲存三個狀态。我們在參數定義時就已經考慮到這一點了,故将init_state設定為3。

接下來,我們繼續建構計算圖中執行RNN計算功能的子產品。

在這段代碼中,我們通過計算<code>current_input</code> <code>Wa</code> + <code>current_state</code> <code>Wbin</code>,得到兩個仿射變換的總和<code>input_and_state_concatenated</code>。在連接配接這兩個張量後,隻用了一個矩陣乘法即可在每個批次中添加所有樣本的偏置<code>b</code>。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖5:第8行代碼的矩陣計算示意圖,省略了非線性變換arctan。

你可能會想知道變量<code>truncated_backprop_lengthis</code>的作用。在訓練時,RNN被看做是一種在每一層都有備援權重的深層神經網絡。在訓練開始時,這些層由于展開後占據了太多的計算資源,是以要在有限的時間步内截斷。在每個批次訓練時,網絡誤差反向傳播了三次。

這是計算圖的最後一部分,我們建立了一個從狀态到輸出的全連接配接層,用于softmax分類,标簽采用One-hot編碼,用于計算每個批次的Loss。

最後一行是添加訓練函數,TensorFlow将自動執行反向傳播函數:對每批資料執行一次計算圖,并逐漸更新網絡權重。

這裡調用的<code>tosparse_softmax_cross_entropy_with_logits</code>函數,能在内部算得softmax函數值後,繼續計算交叉熵。在示例中,各類是互斥的,非0即1,這也是将要采用稀疏自編碼的原因。标簽的格式為<code>[batch_size,num_classes]</code>。

我們利用可視化功能tensorboard,在訓練過程中觀察網絡訓練情況。它将會在時間次元上繪制Loss值,顯示在訓練批次中資料輸入、資料輸出和網絡結構對不同樣本的實時預測效果。

已經完成建構網絡的工作,開始訓練網絡。在TensorFlow中,該計算圖會在一個會話中執行。在每一步開始時,都會随機生成新的資料。

從第15-19行可以看出,在每次疊代中往前移動<code>truncated_backprop_length</code>步,但可能有不同的stride值。這樣做的缺點是,為了封裝相關的訓練資料,<code>truncated_backprop_length</code>的值要顯著大于時間依賴值(本文中為3步),否則可能會丢失很多有效資訊,如圖6所示。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖6:資料示意圖

我們用多個正方形來代表時間序列,上升的黑色方塊表示回波輸出,由輸入回波(黑色方塊)經過三次激活後得到。滑動批處理視窗在每次運作時也滑動了三次,在示例中之前沒有任何批資料,用來封裝依賴關系,是以它不能進行訓練。

請注意,本文隻是用一個簡單示例解釋了RNN如何工作,可以輕松地用幾行代碼中來實作此網絡。此網絡将能夠準确地了解回聲行為,是以不需要任何測試資料。

在訓練過程中,該程式實時更新圖表,如圖7所示。藍色條表示用于訓練的輸入信号,紅色條表示訓練得到的輸出回波,綠色條是RNN網絡産生的預測回波。不同的條形圖顯示了在目前批次中多個批資料的預測回波。

我們的算法能很快地完成訓練任務。左上角的圖表輸出了損失函數,但為什麼曲線上有尖峰?答案就在下面。

如何用TensorFlow建構RNN?這裡有一份極簡的教程什麼是RNN?建構網絡生成資料建構計算圖變量和占位符拆分序列前向傳播計算Loss可視化結果建立訓練會話整個程式

圖7:各圖分别為Loss,訓練的輸入和輸出資料(藍色和紅色)以及預測回波(綠色)。

尖峰的産生原因是在新的疊代開始時,會産生新的資料。由于矩陣重構,每行上的第一個元素與上一行中的最後一個元素會相鄰。但是所有行中的前幾個元素(第一個除外)都具有不包含在該狀态中的依賴關系,是以在最開始的批進行中,網絡的預測功能不良。

這是完整實作RNN網絡的程式,隻需複制粘貼即可運作。如果對文章有什麼疑問,歡迎加量子位小助手qbitbot,注明“加入門群”并做個自我介紹,小助手将帶你和更多小夥伴交流讨論。

本文作者:王小新

原文釋出時間:2017-04-29

繼續閱讀