天天看點

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

原文位址

本篇部落格轉載自知乎,原文見上面的連結。本文主要利用圖檔的形式簡單的介紹了經典RNN、RNN幾個重要變體,以及Seq2Seq模型、Attention機制。旨在讓大家有個初步的印象,之後的幾篇部落格還會更詳細的展開。

目錄

一、單層網絡

二、經典的RNN結構(N vs. N)

三、N vs. 1

四、1 vs. N

五、 N vs. M

六、Attention機制

七、總結

一、單層網絡

在RNN之前,首先了解一下最基本的單層全聯接網絡,結構如下:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

其中輸入是x,經過線性變換

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

和非線性變換激活函數f得到輸出y。

二、經典的RNN結構(N vs. N)

在實際應用中,我們還會遇到很多序列型資料:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
  • 在自然語言處理問題中。
    自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
    可以看作是第一個單詞或其嵌入表示,
    自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
    可以看作是第二個,依次類推。
  • 語音進行中。此時,
    自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
    自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
    自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
    ...是每幀的聲音信号
  • 時間序列問題。例如每天的股票價格等。

序列形資料不太好用原始的全聯接網絡來處理,因為序列中各個term之間是有關聯的,如一句話中的每個單詞都是有關系的,而不是獨立的,而且全聯接網絡一旦搭建起來,輸入是确定的,隻能處理定長的資料,而序列一般是長短不一的。RNN可以處理序列中每個term之間的聯系,而且可以處理變長序列(不過我們用RNN處理序列時,一般把所有的序列統一為一個定長,可以向量化同時對多個序列進行處理,加快速度)。

為了模組化序列問題,RNN引入隐藏狀态h(hidden state)的概念,h可以對序列型資料提取特征,接着再轉換為輸出。

先從

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

(第一個時間步驟上的隐藏狀态)開始看:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

其中,圓圈和方塊代表的是向量/tensor;一個箭頭代表對向量做一次變換。如上圖中

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

分别有一個箭頭連結,表示對

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

各做一次變換。

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的計算和

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

類似。注意計算時,每一個時間步驟上使用的參數U,W,b都是一樣的,也就是每個時間步驟上的參數都是共享的,這是RNN的重要特點。

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

是第2個時間步驟上的隐藏狀态,它編碼了序列中前兩個term的資訊。

使用相同的參數U、W、b,依次計算剩下的:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

友善起見,隻畫了4個時間步驟上的操,即序列長度為4的情況。實際上這個計算過程可以無限持續下去,直到輸入序列的最後一個term。

目前我們的RNN還沒有輸出,得到輸出值的方法就是直接通過隐藏狀态h進行計算:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

和之前一樣,一個箭頭就表示對對應的向量/tensor做一次類似于f(Wx+b)的變換,這裡的箭頭就表示對

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

進行一次變換,得到輸出

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

(第一個時間步驟上的輸出)。

剩下的每個時間步驟上的輸出類似進行(使用和計算

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

時,相同的參數V和c):

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

上圖就是最經典的RNN結構,他的輸入是一個序列

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

,每個時間步驟上輸入序列中的一個term 

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

,輸出是一個等長的序列

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

,每個時間步驟上産生一個輸出

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

。注意上述結構下的輸入/輸出序列是等長的。

由于這個限制存在,經典RNN的使用範圍比較小,但也有一些問題适合用經典的RNN搭建:

  • 計算視訊中每一幀的分類标簽。因為要對每一幀進行計算,是以輸入和輸出序列等長。
  • 輸入為一個字元,輸出為預測下一個字元出現的機率。這就是char RNN,他可以用來生成文章、詩歌甚至代碼。可以參考The Unreasonable Effectiveness of Recurrent Neural Networks。

三、N vs. 1

有時,我們要處理的問題輸入是一個序列,輸出是一個單獨的值而不是一個序列,此時隻需要對最後一個時間步驟的隐藏狀态h進行輸出變換即可:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

這種結構通常用于處理序列分類問題。如輸入一段文本判斷它所屬的類别,或判斷其情感傾向,輸入一段視訊判斷他的類别等。

四、1 vs. N

下面看一下,如果輸入不是序列而輸出是序列的情況如何處理,可以隻在序列開始進行輸入計算:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

還有一種結構是把輸入資訊x作為每個時間步驟上的輸入:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

下面是一個等價表示:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

這種1 vs. N的結構可以處理的問題:

  • 從圖像生成文字(image caption)/生成圖像描述:此時輸入的X就是編碼/提取的圖像特征,而輸出的y序列就是可以描述該圖像内容的一段句子。
  • 從類别生成語音或音樂等

五、 N vs. M

下面介紹RNN最重要的一個變種:N vs. M。 這種結構又叫Encoder-Decoder模型,也可以稱為Seq2Seq模型。

原始的RNN要求序列等長,即N vs. N,然而我們遇到的大部分問題序列都是不等長的,如機器翻譯、語音識别等,源序列和目标序列往往沒有相同的長度。

為此,Encoder-Decoder結構先将輸入序列編碼為一個上下文向量c,用它來表示輸入序列的語義特征:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

得到c的方式有多種,最簡單的就是把Encoder的最後一個時間步驟的隐藏狀态指派給c,還可以對最後的隐藏狀态做一個變換得到c,也可以對所有時間步驟上隐藏狀态最一個變換得到c。

得到c之後,就要用另一個RNN網絡對其進行解碼,這部分RNN網絡稱為Decoder。具體做法是把Encoder産生的c當作初始狀态

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

輸入到Decoder中:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

還有一種做法是将c作為每一個時間步驟的輸入:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

由于這種Encoder-Decoder結構不限制輸入和輸出序列的長度,是以應用範圍非常廣泛,比如:

  • 機器翻譯:Encoder-Decoder最經典的應用,事實上這一結構就是在機器翻譯領域最先提出。
  • 文本摘要:輸入是一段文本序列,輸出是這段文本序列的摘要序列。
  • 閱讀了解:将輸入文章和問題分别編碼,再對其進行解碼得到問題的答案。
  • 語音識别:輸入是語音序列,輸出是文字序列
  • ............

六、Attention機制

在Encoder-Decoder結構中,Encoder把所有輸入序列都編碼成一個統一的語義特征c再解碼,是以c必須包含原始序列中所有資訊,他的長度就成了限制模型性能的瓶頸。如機器翻譯問題,當要翻譯的句子較長時,一個c可能存不下那麼多資訊,就會造成翻譯靜笃的下降。

Attention機制通過在Decoder的每一個時間步驟上輸入不同的c來解決這個問題,下圖是帶有Attention機制的Decoder:

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

每一個c會自動去選取與目前時間步驟所要輸出的y最合适的上下文資訊。具體來說,我們用

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

衡量Encoder中第j階段的隐藏狀态

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

和Decoder中第i階段的相關性,最終Decoder中第i階段的輸入的上下文資訊

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

,就來自于Encoder中所有

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的權重和。

以機器翻譯為例(中->英):

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

輸入序列是"我愛中國",是以Encoder中各個時間步驟/階段的隐藏狀态

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

分别可以大緻看作是"我","愛","中","國"所代表的資訊(嚴格來說,

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

應該包含了"我愛"的資訊,因為他不僅包含目前時間步驟的輸入,也包含上一個時間步驟隐藏狀态的輸入,但這裡我們看作它主要包含"愛"的資訊,其他h也是一樣)。在翻譯成英語時,第一個上下文

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

應該和"我"這個字最相關,是以對應的

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

就比較大,而相應的

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

就比較小。

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

應該和"愛"最相關,是以對應的

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

就比較大。最後

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

和“中國”或

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

最相關,是以

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的值就比較大。

至此,關于Attention機制,最後一個問題是這些權重

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

是怎麼來的?

事實上,

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

同樣是從模型/資料中學出來的,他實際和Decoder的第i-1階段的隐藏狀态和Encoder第j個階段的隐藏狀态有關,利用他們經過特殊的運算得到(之後的部落格我們會詳細介紹運算細節)。

同樣拿上面的機器翻譯舉例,

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的計算(箭頭表示對

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

同時作變換):

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的計算(箭頭表示對

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

同時作變換)

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介
自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

的計算(箭頭表示對

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

同時作變換)

自然語言處理 | (24) RNN、RNN變體、Seq2Seq、Attention機制簡介

以上就是帶有Attention的Encoder-Decoder模型計算的全過程。

七、總結

本文主要講了N vs N,N vs 1、1 vs N、N vs M四種經典的RNN模型,以及如何使用Attention結構。希望能對大家有所幫助。

上述RNN采用的都是樸素RNN單元,當然也可以換成LSTM單元或GRU單元,他們更擅長捕捉長依賴關系,内部的運算會更加複雜,但基本原理是一樣的,之後我們還會介紹。可以參考這篇文章Understanding LSTM Networks來了解一下LSTM的内部結構。

上述提到的RNN結構都是最基本的結構,相當于積木,在實際上使用時,會産生各種各樣的變體,如雙向RNN,将RNN單元替換為LSTM或GRU,堆疊多層RNN,添加dropout或batch normaliaztion防止過拟合,加速訓練以及transformeri 結構等。但是我們隻有把這些基礎結構學好,學會每塊積木的用途和原理,接下裡就可以靈活的使用這些積木,基于自己的應用,搭建出各種複雜的網絡結構了。

繼續閱讀