天天看點

教程 | 一個基于TensorFlow的簡單故事生成案例:帶你了解LSTM

在深度學習中,循環神經網絡(rnn)是一系列善于從序列資料中學習的神經網絡。由于對長期依賴問題的魯棒性,長短期記憶(lstm)是一類已經有實際應用的循環神經網絡。現在已有大量關于 lstm 的文章和文獻,其中推薦如下兩篇:

goodfellow et.al.《深度學習》一書第十章:http://www.deeplearningbook.org/

chris olah:了解 lstm:http://colah.github.io/posts/2015-08-understanding-lstms/

已存在大量優秀的庫可以幫助你基于 lstm 建構機器學習應用。在 github 中,谷歌的 tensorflow 在此文成文時已有超過 50000 次星,表明了其在機器學習從業者中的流行度。

與此形成對比,相對缺乏的似乎是關于如何基于 lstm 建立易于了解的 tensorflow 應用的優秀文檔和示例,這也是本文嘗試解決的問題。

假設我們想用一個樣本短故事來訓練 lstm 預測下一個單詞,伊索寓言:

long ago , the mice had a general council to consider what measures they could take to outwit their common enemy , the cat . some said this , and some said that but at last a young mouse got up and said he had a proposal to make , which he thought would meet the case . you will all agree , said he , that our chief danger consists in the sly and treacherous manner in which the enemy approaches us . now , if we could receive some signal of her approach , we could easily escape from her . i venture , therefore , to propose that a small bell be procured , and attached by a ribbon round the neck of the cat . by this means we should always know when she was about , and could easily retire while she was in the neighbourhood . this proposal met with general applause , until an old mouse got up and said that is all very well , but who is to bell the cat ? the mice looked at one another and nobody spoke . then the old mouse said it is easy to propose impossible remedies .

listing 1.取自伊索寓言的短故事,其中有 112 個不同的符号。單詞和标點符号都視作符号。

如果我們将文本中的 3 個符号以正确的序列輸入 lstm,以 1 個标記了的符号作為輸出,最終神經網絡将學會正确地預測下一個符号(figure1)。

教程 | 一個基于TensorFlow的簡單故事生成案例:帶你了解LSTM

圖 1.有 3 個輸入和 1 個輸出的 lstm 單元

嚴格說來,lstm 隻能了解輸入的實數。一種将符号轉化為數字的方法是基于每個符号出現的頻率為其配置設定一個對應的整數。例如,上面的短文中有 112 個不同的符号。如清單 2 所示的函數建立了一個有如下條目 [「,」: 0 ] [「the」: 1 ], …, [「council」: 37 ],…,[「spoke」= 111 ] 的詞典。而為了解碼 lstm 的輸出,同時也生成了逆序字典。

listing 2.建立字典和逆序字典的函數

類似地,預測值也是一個唯一的整數值與逆序字典中預測符号的索引相對應。例如:如果預測值是 37,預測符号便是「council」。

輸出的生成看起來似乎簡單,但實際上 lstm 為下一個符号生成了一個含有 112 個元素的預測機率向量,并用 softmax() 函數歸一化。有着最高機率值的元素的索引便是逆序字典中預測符号的索引值(例如:一個 one-hot 向量)。圖 2 給出了這個過程。

教程 | 一個基于TensorFlow的簡單故事生成案例:帶你了解LSTM

圖 2.每一個輸入符号被配置設定給它的獨一無二的整數值所替代。輸出是一個表明了預測符号在反向詞典中索引的 one-hot 向量。

lstm 模型是這個應用的核心部分。令人驚訝的是,它很易于用 tensorflow 實作:

listing 3.有 512 個 lstm 單元的網絡模型

最難部分是以正确的格式和順序完成輸入。在這個例子中,lstm 的輸入是一個有 3 個整數的序列(例如:1x3 的整數向量)

網絡的常量、權值和偏差設定如下:

listing 4.常量和訓練參數

訓練過程中的每一步,3 個符号都在訓練資料中被檢索。然後 3 個符号轉化為整數以形成輸入向量。

listing 5.将符号轉化為整數向量作為輸入

訓練标簽是一個位于 3 個輸入符号之後的 one-hot 向量。

listing 6.單向量作為标簽

在轉化為輸入詞典的格式後,進行如下的優化過程:

listing 7.訓練過程中的優化

精度和損失被累積以監測訓練過程。通常 50,000 次疊代足以達到可接受的精度要求。

listing 8.一個訓練間隔的預測和精度資料示例(間隔 1000 步)

代價是标簽和 softmax() 預測之間的交叉熵,它被 rmsprop 以 0.001 的學習率進行優化。在本文示例的情況中,rmsprop 通常比 adam 和 sgd 表現得更好。

listing 9.損失和優化器

lstm 的精度可以通過增加層來改善。

listing 10. 改善的 lstm

現在,到了有意思的部分。讓我們通過将預測得到的輸出作為輸入中的下一個符号輸入 lstm 來生成一個故事吧。示例輸入是「had a general」,lstm 給出了正确的輸出預測「council」。然後「council」作為新的輸入「a general council」的一部分輸入神經網絡得到下一個輸出「to」,如此循環下去。令人驚訝的是,lstm 創作出了一個有一定含義的故事。

listing 11.截取了樣本故事生成的故事中的前 32 個預測值

如果我們輸入另一個序列(例如:「mouse」,「mouse」,「mouse」)但并不一定是這個故事中的序列,那麼會自動生成另一個故事。

listing 12.并非來源于示例故事中的輸入序列

示例代碼可以在這裡找到:https://github.com/roatienza/deep-learning-experiments/blob/master/experiments/tensorflow/rnn/rnn_words.py

示例文本的連結在這裡:https://github.com/roatienza/deep-learning-experiments/blob/master/experiments/tensorflow/rnn/belling_the_cat.txt

小貼士:

1. 用整數值編碼符号容易操作但會丢失單詞的意思。本文中将符号轉化為整數值是用來簡化關于用 tensorflow 建立 lstm 應用的讨論的。更推薦采用 word2vec 将符号編碼為向量。

2. 将輸出表達成單向量是效率較低的方式,尤其當我們有一個現實的單詞量大小時。牛津詞典有超過 170,000 個單詞,而上面的例子中隻有 112 個單詞。再次聲明,本文中的示例隻為了簡化讨論。

3. 這裡采用的代碼受到了 tensorflow-examples 的啟發:https://github.com/aymericdamien/tensorflow-examples/blob/master/examples/3_neuralnetworks/recurrent_network.py

4. 本文例子中的輸入大小為 3,看一看當采用其它大小的輸入時會發生什麼吧(例如:4,5 或更多)。

5. 每次運作代碼都可能生成不同的結果,lstm 的預測能力也會不同。這是由于精度依賴于初始參數的随機設定。訓練次數越多(超過 150,000 次)精度也會相應提高。每次運作代碼,建立的詞典也會不同

6. tensorboard 在調試中,尤其當檢查代碼是否正确地建立了圖時很有用。

7. 試着用另一個故事測試 lstm,尤其是用另一種語言寫的故事。

原文連結:https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537

本文來源于"中國人工智能學會",原文發表時間"2017-04-25 "

繼續閱讀