Attention機制與Transformer模型,以及基于Transformer模型的預訓練模型BERT的出現,對NLP領域産生了變革性提升。現在在大型NLP任務、比賽中,基本很少能見到RNN的影子了。大部分是BERT(或是其各種變體,或者還加上TextCNN)做特征提取(feature extraction)或是微調(fine-tune),結合最後的全連接配接層+softmax/sigmoid,完成各類NLP任務。
這章主要介紹Transformer模型的基礎:Attention機制。在介紹Attention機制前,有必要介紹一下Seq2Seq(Sequence to Sequence)模型,它是這一切的基礎。
Seq2Seq(Sequence to Sequence)模型是一個多對多的模型,最早由Bengio于2014年的論文[1]提出。當時仍是RNN的時代,還未出現Attention。是以我們在介紹時Seq2Seq時也會以RNN為基礎。SeqSeq最常見的應用是在機器翻譯,例如輸入一句英文,翻譯出一句德文,這兩者的輸入輸出長度都不是固定的,而Seq2Seq便适用于這種輸入輸出長度不固定的場景。
Seq2Seq模型包含2個部分:Encoder和Decoder。以機器翻譯(英語翻譯為德語)為例,Seq2Seq模型如下圖所示:

Fig. 1. ShusenWang. Neural Machine Translation
左邊部分稱為Encoder(編碼器),可以是任意RNN,例如上一章介紹過的LSTM。假設有一序列長度為m的英語句子,輸入到Encoder(例如LSTM),在經過時間步t後,輸出最終狀态state hm(或者同時也輸出LSTM的傳輸帶向量Cm)。此時hm便包含了這整句英語句子的資訊。
右邊部分稱為Decoder(解碼器),也可以是任意RNN,例如LSTM。Encoder的輸出狀态hm,作為初始狀态h0輸入到Decoder中。Decoder的第一個輸入是一個辨別單詞,例如[start],代表翻譯起始,除此之外不代表任何含義。有了初始狀态h0和第一個輸入[start],Decoder中的RNN(例如LSTM)即可進行計算生成第一個輸出狀态s1。在機器翻譯中,s1會輸入到一個softmax中,然後輸出預測的德語單詞z1。然後z1會作為下一時間步t2的輸入,同時s1也會作為下一時間步t2的輸入狀态,計算得出時間步t2的狀态s2,以及s2輸入softmax後預測的單詞z2。疊代這個過程,直到輸出一個終止辨別單詞,例如[end],此辨別單詞與[start]一樣,僅代表翻譯終止,除此之外不代表任何含義。
在訓練一個Seq2Seq時,仍以機器翻譯為例(英語翻譯為德語)。Encoder仍是輸入英語句子sequence,得到狀态hm(或者同時也輸出LSTM的傳輸帶向量Cm),輸入到Decoder中。Decoder第一個輸入為[start],輸入的狀态為hm,計算s1,然後得到預測值z1。在測試集中,label對應的是一句德語句子Sequence。此時,對比第1個預測值z1與德語句子Sequence的第1個單詞的差異(使用CrossEntropy(y1, z1)),得到loss。然後反向傳播更新權重參數。然後繼續計算第2個時間步t2的s2與z2,并對比第2個預測值z2與德語句子Sequence的第2個單詞的差異,得到loss并反向傳播更新參數。疊代此過程即可。
在Encoder部分中,若是輸入序列比較長,則會容易忘記最開始的内容。此時一個優化的方法就是對Encoder使用雙向RNN(例如雙向LSTM),這樣可以提高Encoder所攜帶的資訊。不過Decoder必須是單向的。
另一方面,可以使用同一個Encoder訓練多個Decoder,這個稱為多任務學習。例如,輸入英語,将其翻譯為多個其他語言。這樣Encoder的訓練資料就多了幾倍,可以将Encoder訓練的更好。雖然其中任一Decoder的訓練資料沒有增加,但仍能增強翻譯的效果,因為Encoder的效果得到了提升。
Seq2Seq模型的結構非常簡單,其實就是2個RNN模型串起來。Encoder部分處理輸入,并将所有輸入資訊儲存在一個狀态向量h中,輸入到Decoder中進行各類任務。前面提到,最開始Seq2Seq提出時,主要基于的還是RNN。不過在2015年,随着Attention問世後,極大提升了Seq2Seq模型的效果,達到了超出了傳統RNN的性能。下面我們介紹Attention機制。
Attention的論文于2015年發表,用于改進Seq2Seq模型。上面我們也提到過原始Seq2Seq模型存在的問題,那就是:處理長序列的能力有限。因為Encoder中僅有最後一個時間步的狀态hm,作為context向量輸入到Decoder中。若是輸入序列比較長,則會容易忘記sequence位置靠前的輸入。雖然前面提到的雙向LSTM作為Encoder可以在一定程度上緩解此問題,但仍未根本解決此問題。是以RNN-Based Seq2Seq僅适合于短序列(序列長度 < 20)。
Attention機制便是為了直接解決此問題而提出。使用了Attention後,在Decoding的過程中,它不會僅使用Encoder最終輸出的單個狀态hm,而是會使用到所有輸入序列的hidden states。并且Attention還會告知Decoder,應該關注Encoder中哪個狀态。Attention可以大幅提高Seq2Seq模型的準确率,但是代價是計算量非常大。
簡單地來說,Attention在Seq2Seq模型中的計算分為以下幾步:
Encoder計算産生每個時間步的Hidden State。例如輸入是x = [x1, x2, …, xm],對應每個時間步的輸出為向量h = [h1, h2, …, hm]。在原始Seq2Seq模型中,僅有最後一個狀态hm會被保留,但是在Attention + Seq2Seq中,每個中間狀态hi都會保留。
此時Decoder的第一個輸入狀态為s0,計算s0與每個[h1, h2, …, hm]的Alignment Score(它可以了解為s0與每個hi的相關性得分,這個Alignment Score 有多種計算的方法,稍後會介紹),得到向量a = [a1, a2, …, am]。
對Alignment Score(也就是向量a)做softmax處理,此時向量a的所有元素被壓縮到[0, 1] 的範圍内,且所有元素相加的和為1(也就是說,每個元素代表1個比例值)。
将向量a與h做内積得到Context Vector(上下文向量)c0 = a1 * h1 + a2 * h2 +, …, + am * hm。
Context向量c0與Decoder的輸入狀态向量s0,以及Decoder目前時間步t0的輸入x’0做拼接(concatenate)操作,得到 [c0, s0, x’0],輸入到Decoder中,得到Decoder的下一個時間步t1的狀态s1。如果這是一個機器翻譯的任務,則s1(例如輸入一個全連接配接層)會進一步用于預測此時間步t0的輸出x’1,并用于下一個時間步t1的輸入。
疊代以上2-5步的過程,直到Decoder達到指定長度或是輸出指定停止信号(例如輸出[END])
整個過程如下圖所示:
Fig. 3. ShusenWang. Seq2Seq Model with Attention[3]
從這個過程可以看到,在Decoder每次處理一個時間步輸入時,都會再周遊一遍Encoder的每個時間步的狀态。這樣便解決了Seq2Seq模型對長序列記憶力有限的問題。但同時,由于Decoder中每個時間步t都要重新計算Alignment Score,是以很大程度上增加了計算量,這便是Attention提高準确率的代價。
最後,上面提到Alignment Score進入softmax後,結果向量a的每個元素的值是一個比例,其總和為1。此時在計算Context 向量c時,向量a與hidden state向量h做的點乘操作。此時,即可視向量a中的每個元素為權重,權重ai的大小表示了對應位置的hi的重要程度(當然,這個權重是訓練過程中不斷優化後得到的),也即是說:在目前時間步ti,Encoder中各個hidden state hi對Decoder此時輸出值的影響程度。這便是我們前面提到的“Attention還會告知Decoder,應該關注Encoder中哪個狀态”。
在了解了Attention的基本邏輯後,我們繼續介紹Attention中計算Alignment Score的方法。
Bahdanau Attention是原論文[4]中提出的方法,以論文第一作者Bahdanau的名字為命名得來。它的Encoder部分為雙向RNN。Attention部分的計算方式如下表示:
Scorealignment = Wcombined * tanh (Wdecoder * Hdecoder + Wencoder * Hencoder)
這個公式中一共有3個參數矩陣Wcombined,Wdecoder和Wencoder。在計算Alignment Score時,假設目前時間步為t0,此時Hdecoder為Decoder的第一個輸入狀态s0,Hencoder為Encoder的輸出狀态向量h。Hdecoder與Hencoder分别輸入到2個全連接配接網絡層(FC層)中:
Fig. 4. Gabriel Loye. Attention Mechanism [5]
然後将其做矩陣加法後,送入tanh函數,将其範圍壓縮到(-1, 1)之間:
Fig. 5. Gabriel Loye. Attention Mechanism [5]
最後再送入到一個全連接配接層,得到最終的Alignment Score:
Fig. 6. Gabriel Lo
ye. Attention Mechanism [5]
這個過程會涉及到3個全連接配接層,是以3個參數矩陣Wcombined,Wdecoder和Wencoder 均是可訓練參數。在訓練過程中不斷優化得到最終參數值。由于這個方法中Hdecoder與Hencoder在(分别與參數矩陣做乘法後)合并時使用的是加法,是以這種方法也稱為Additive Attention。
Luong Attention 由Thang Luong 于2015年提出[6]。它相對于Bahdanau Attention有3點不同:
在Decoder中引入Attention的位置不同
在Decoder中的輸入輸出不同
計算Alignment Score的方法不同
Luong Attention的計算過程步驟為:
Encoder部分生成所有Hidden State
在Decoder中,假設目前時間步為t。使用上一時間步t-1的狀态st-1與輸出outputt-1,計算出一個新的狀态st
使用st與Encoder中的Hidden State計算Alignment Score
将Alignment Score 送入softmax,得到權重向量at
使用權重向量at與Encoder中的Hidden State做點積得到上下文向量ct
将ct與步驟2中得到的st做拼接[ct, st],輸入一個全連接配接層,得到一個新的輸出s’t,此狀态向量s’t即為目前時間步t下,Decoder輸出的真正的狀态(相對于第2步輸出的狀态st)。此狀态向量s’t可以繼而輸入到一個全連接配接層,執行所需任務(例如機器翻譯中預測下一個單詞)
從這個計算過程,可以看到:
在Decoder中引入Attention的位置不同是指:相對于Bahdanau使用的是上一時間步t-1的狀态st-1計算Alignment Score。而Luong Attention使用的是目前時間步t的狀态st
在Decoder中的輸入輸出不同是指:Luong Attention中,會先計算出一個時間步t的狀态st。但此狀态向量并非為最終的t時刻的狀态,而是用此狀态st再計算出一個新的狀态s’t,這才是Decoder在t時間步的真正狀态
在計算Alignment Score方面,Luong Attention提供了3種計算Alignment Score的方法,分别稱為dot、general以及concat方法。
Dot(點積)方法非常簡單,直接使用Hencoder與Decoder的隐藏狀态s做點乘:
Scorealignment = Hencoder * st
General方法與dot方法類似,是dot方法的一個變體,但是加了一個可訓練的權重矩陣W:
Scorealignment = W(Hencoder * st)
Concat方法與Bahdanau Attention中使用的類似,但是會先将Hencoder 與 st 相加,然後送入一個全連接配接層,是以它們共享1個權重矩陣:
Scorealignment = W * tanh(Wconbined(Hencoder + st))
這3種方法中,General方法效果最好,是以現在主要使用General方法。而由于此方法中Hencoder與Decoder狀态st在合并時使用的是乘法,是以此方法也稱為Multiplicative Attention。
最後介紹現在更常用的Alignment Score計算方法,也是Transformer模型種用的方法。此方法很簡單,涉及到2個參數矩陣WK,WQ,步驟如下:
将Encoder隐藏狀态h = [h1, h2, …, hm]與參數矩陣WK做線性變換(也就是WK * h),得到向量k = [k1, k2, …, k3]
将Decoder在時間步t-1的狀态st-1與參數矩陣WQ做線性變換(也就是 WQ * st-1),得到qt-1
使用向量k與qt-1做内積,即得到Alignment Score = [k1 * qt-1, k2 * q t-1, …, km * q t-1]。對它做softmax,即得到權重向量a。
有關這部分更詳細的内容會在介紹Transformer時進一步介紹。
這章介紹了Seq2Seq模型與Attention機制,以及Attention的3種不同實作。除此之外,Attention還有其它2種非常重要的變體:Self-Attention與Multi-Head Attention。這2個變體會在介紹Transformer模型時具體介紹。
本來預期這章會介紹Attention機制,并會開始介紹Transformer模型,但是由于寫的内容比預期要多,是以會在下一章再開始介紹Transformer模型。
[1] https://arxiv.org/pdf/1406.1078.pdf
[2]https://raw.githubusercontent.com/wangshusen/DeepLearning/master/Slides/9_RNN_6.pdf
[3] https://www.bilibili.com/video/BV1YA411G7Ep
[4] https://arxiv.org/pdf/1409.0473.pdf
[5] https://blog.floydhub.com/attention-mechanism/
[6] https://arxiv.org/abs/1508.04025