天天看點

從原理到實戰 英偉達教你用PyTorch搭建RNN(下)

建立模型時,spinn.__init__被調用一次。它配置設定、初始化參數,但不進行任何神經網絡運算,也不涉及建立計算圖。每組新資料 batch 上運作的代碼,在 spinn 中定義。pytorch 裡,使用者定義模型前饋通道的方法名為 “forward”。事實上,它是對上文提到的 stack-manipulation 算法的實作,在普通 python 裡,它運作于 buffer 和堆棧的 batch 上——對每個樣例使用兩者之一。 在轉換過程包含的“shift” 和 “reduce” op 上疊代,如果它存在,就運作 tracker,并運作于 batch 中的每個樣例以應用 “shift”op,或加入需要 “reduce” op 的樣例清單。然後在清單所有的樣例上運作 reduce 層,把結果 push 回相關堆棧。

調用 self.tracker 或 self.reduce,會相對應地運作 tracker 中的“forward”方式,或 reduce 子子產品。這需要在一個樣例清單來執行該 op。所有數學運算密集、用 gpu 加速、收益用 batch 的 op 都發生在  tracker 和 reduce 之中。是以,在主要的“forward”方式中,單獨在不同樣例上運作;對 batch 中的每個樣例保持獨立的 buffer 和堆棧,都是意義的。為了更幹淨地寫這些函數,我會用一些輔助,把這些樣例清單轉為 batch 化的張量,反之亦然。

我傾向于讓 reduce 子產品自動 batch 參數來加速計算,然後 unbatch 它們,這樣之後能單獨地 push、pop。把每一組左右子短語放到一起,來表示母短語的合成函數是 treelstm,一個正常 lstm 的變種。此合成函數要求,所有子樹的狀态要由兩個張量組成,一個隐藏狀态 h 和一個記憶體單元狀态 c。定義該函數的因素有兩個:運作于子樹隐藏狀态中的兩個線性層  (nn.linear),以及非線性合成函數 tree_lstm,後者把線性層的結果和子樹記憶體單元的狀态組合起來。在 spinn 中,這通過加入第三個運作于 tracker 隐藏狀态的 線性層來拓展。

從原理到實戰 英偉達教你用PyTorch搭建RNN(下)

由于 reduce 層和以與之類似方式執行的 tracker 都在 lstm 上運作,batch 和 unbatch 輔助函數會在成對隐藏、記憶體狀态上運作。

從原理到實戰 英偉達教你用PyTorch搭建RNN(下)

上文描述的、該模型不含  tracker 的版本,其實特别适合 tensorflow 的 tf.fold,針對動态計算圖特殊情形的 tensorflow 新專用語言。包含 tracker 的版本實作起來要難得多。這背後的原因是:加入  tracker,就意味着從 recursive 模式切換為基于堆棧的模式。在上面的代碼裡,這以最直覺的形式表現了出來,這使用的是取決于輸入值的 conditional branches。 fold 并沒有内建的 conditional branch op,是以模型裡的圖結構隻取決于輸入的結構而非值。另外,建立一個由 tracker 決定如何解析輸入語句的 spinn 實際上是不可能的。這是因為 fold 裡的圖結構——雖然它們取決于輸入樣例的結構,在一個輸入樣例加載之後,它必須完全固定下來。

deepmind 和谷歌大腦的研究人員正在摸索一個類似的模型。他們用強化學習來訓練一個 spinn 的 tracker,來解析輸入語句,而不需要任何外部解析資料。本質上,這樣的模型以随機的猜想開始,當它的解析在整體分類任務上生成較好精度時,獎勵它自己,以此來學習。研究人員們寫道,他們“使用 batch size 1,因為取決于 policy network [tracker] 的樣本, 對于每個樣例,計算圖需要在每次疊代後重建。”但即便在像本文這麼複雜、結構有随機變化特性的神經網絡上,在 pytorch 上,研究人員們也能隻用 batch 訓練。

pytorch 還是第一個在算法庫内置了強化學習的架構,即它的 stochastic computation graphs (随機計算圖)。這使得 policy gradient 強化學習像反向傳播一樣易于使用。若想要把它加入上面描述的模型,你隻需要像重寫主 spinn 的頭幾行代碼,生成下面一樣的循環,讓 tracker 來定義做任何一種解析器(parser)轉換的機率。

當 batch 一路運作下來,模型知道了它的類别預測精确程度之後,我可以在反向傳播之外,用傳統方式通過圖的其餘部分把獎勵信号傳回這些随機計算圖節點:

谷歌研究人員從 spinn+增強學習報告的結果,比在 snli 獲得的原始 spinn 要好一點,雖然它的增強學習版并沒有預計算文法樹。深度增強學習在 nlp 的應用是一個全新的領域,其中的研究問題十分廣泛。通過把增強學習整合到架構裡,pytorch 極大降低了使用門檻。

====================================分割線================================

本文作者:三川

繼續閱讀