天天看點

通過代碼解讀自注意力機制

通過代碼解讀自注意力機制
【新智元導讀】NLP領域最近的快速進展離不開基于Transformer的架構,本文以圖解+代碼的形式,帶領讀者完全了解self-attention機制及其背後的數學原理,并擴充到Transformer。

BERT, RoBERTa, ALBERT, SpanBERT, DistilBERT, SesameBERT, SemBERT, MobileBERT, TinyBERT, CamemBERT……它們有什麼共同之處呢?答案不是“它們都是BERT”????。

正确答案是:self-attention????。

我們讨論的不僅是名為“BERT”的架構,更準确地說是基于Transformer的架構。基于Transformer的架構主要用于模組化語言了解任務,它避免了在神經網絡中使用遞歸,而是完全依賴于self-attention機制來繪制輸入和輸出之間的全局依賴關系。但這背後的數學原理是什麼呢?

這就是本文要講的内容。這篇文章将帶你通過一個self-attention子產品了解其中涉及的數學運算。讀完本文,你将能夠從頭開始寫一個self-attention子產品。

讓我們開始吧!

完全圖解——8步掌握self-attention

self-attention是什麼?

如果你認為self-attention與attention有相似之處,那麼答案是肯定的!它們基本上共享相同的概念和許多常見的數學運算。

一個self-attention子產品接收n個輸入,然後傳回n個輸出。這個子產品中發生了什麼呢?用外行人的話說,self-attention機制允許輸入與輸入之間彼此互動(“self”),并找出它們應該更多關注的對象(“attention”)。輸出是這些互動和注意力得分的總和。

寫一個self-attention子產品包括以下步驟

  • 準備輸入
  • 初始化權重
  • 推導key, query 和 value
  • 計算輸入1的注意力得分
  • 計算softmax
  • 将分數與值相乘
  • 将權重值相加,得到輸出1
  • 對輸入2和輸入3重複步驟4-7
注:實際上,數學運算是矢量化的,,即所有的輸入都一起經曆數學運算。在後面的代碼部分中可以看到這一點。

步驟1:準備輸入

通過代碼解讀自注意力機制

圖1.1: 準備輸入

在本教程中,我們從3個輸入開始,每個輸入的維數為4。

通過代碼解讀自注意力機制

步驟2:初始化權重

每個輸入必須有三個表示(見下圖)。這些表示稱為鍵(key,橙色)、查詢(query,紅色)和值(value,紫色)。在本例中,我們假設這些表示的維數是3。因為每個輸入的維數都是4,這意味着每組權重必須是4×3。

注:

稍後我們将看到value的次元也是輸出的次元。

通過代碼解讀自注意力機制

圖1.2:從每個輸入得出鍵、查詢和值的表示

為了得到這些表示,每個輸入(綠色)都乘以一組鍵的權重、一組查詢的權重,以及一組值的權重。在本示例中,我們将三組權重“初始化”如下。

key的權重:

通過代碼解讀自注意力機制

query的權重:

通過代碼解讀自注意力機制

value的權重:

通過代碼解讀自注意力機制
在神經網絡設定中,這些權重通常是很小的數字,使用适當的随機分布(例如高斯、Xavier和Kaiming分布)進行随機初始化。

步驟3:推導鍵、查詢和值

現在,我們有了三組權重,讓我們實際擷取每個輸入的鍵、查詢和值表示。

輸入1的鍵表示:

通過代碼解讀自注意力機制

使用相同的權重集合得到輸入2的鍵表示:

通過代碼解讀自注意力機制

使用相同的權重集合得到輸入3的鍵表示:

通過代碼解讀自注意力機制

一種更快的方法是對上述操作進行矢量化:

通過代碼解讀自注意力機制
通過代碼解讀自注意力機制

圖1.3a:從每個輸入推導出鍵表示

同樣的方法,可以擷取每個輸入的值表示:

通過代碼解讀自注意力機制
通過代碼解讀自注意力機制

圖1.3b:從每個輸入推導出值表示

最後,得到查詢表示

通過代碼解讀自注意力機制
通過代碼解讀自注意力機制

圖1.3b:從每個輸入推導出查詢表示

在實踐中,偏差向量(bias vector )可以添加到矩陣乘法的乘積。

步驟4:計算輸入1的attention scores

通過代碼解讀自注意力機制

圖1.4:從查詢1中計算注意力得分(藍色)

為了獲得注意力得分,我們首先在輸入1的查詢(紅色)和所有鍵(橙色)之間取一個點積。因為有3個鍵表示(因為有3個輸入),我們得到3個注意力得分(藍色)。

通過代碼解讀自注意力機制
注:現在隻使用Input 1中的查詢。稍後,我們将對其他查詢重複相同的步驟。

步驟5:計算softmax

通過代碼解讀自注意力機制

圖1.5:Softmax注意力評分(藍色)

在所有注意力得分中使用softmax(藍色)。

通過代碼解讀自注意力機制

步驟6:将得分和值相乘

通過代碼解讀自注意力機制

圖1.6:由值(紫色)和分數(藍色)的相乘推導出權重值表示(黃色)

每個輸入的softmaxed attention 分數(藍色)乘以相應的值(紫色)。結果得到3個對齊向量(黃色)。在本教程中,我們将它們稱為權重值。

通過代碼解讀自注意力機制

步驟7:将權重值相加得到輸出1

圖1.7:将所有權重值(黃色)相加,得到輸出1(深綠色)

将所有權重值(黃色)按元素指向求和:

通過代碼解讀自注意力機制

結果向量[2.0,7.0,1.5](深綠色)是輸出1,該輸出基于輸入1與所有其他鍵(包括它自己)進行互動的查詢表示。

步驟8:重複輸入2和輸入3

現在,我們已經完成了輸出1,我們對輸出2和輸出3重複步驟4到7。接下來相信你可以自己操作了????????。

通過代碼解讀自注意力機制

圖1.8:對輸入2和輸入3重複前面的步驟

代碼上手

這是PyTorch代碼????,PyTorch是Python的一個流行的深度學習架構。

import torch                  x = [              [1, 0, 1, 0], # Input 1              [0, 2, 0, 2], # Input 2              [1, 1, 1, 1]  # Input 3              ]              x = torch.tensor(x, dtype=torch.float32)           
w_key = [              [0, 0, 1],              [1, 1, 0],              [0, 1, 0],              [1, 1, 0]              ]              w_query = [              [1, 0, 1],              [1, 0, 0],              [0, 0, 1],              [0, 1, 1]              ]              w_value = [              [0, 2, 0],              [0, 3, 0],              [1, 0, 3],              [1, 1, 0]              ]              w_key = torch.tensor(w_key, dtype=torch.float32)              w_query = torch.tensor(w_query, dtype=torch.float32)              w_value = torch.tensor(w_value, dtype=torch.float32)           

步驟3: 推導鍵、查詢和值

keys = x @ w_key              querys = x @ w_query              values = x @ w_value                  print(keys)              # tensor([[0., 1., 1.],              #         [4., 4., 0.],              #         [2., 3., 1.]])                  print(querys)              # tensor([[1., 0., 2.],              #         [2., 2., 2.],              #         [2., 1., 3.]])                  print(values)              # tensor([[1., 2., 3.],              #         [2., 8., 0.],              #         [2., 6., 3.]])           

步驟4:計算注意力得分

attn_scores = querys @ keys.T                  # tensor([[ 2.,  4.,  4.],  # attention scores from Query 1              #         [ 4., 16., 12.],  # attention scores from Query 2              #         [ 4., 12., 10.]]) # attention scores from Query 3           
from torch.nn.functional import softmax                  attn_scores_softmax = softmax(attn_scores, dim=-1)              # tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],              #         [6.0337e-06, 9.8201e-01, 1.7986e-02],              #         [2.9539e-04, 8.8054e-01, 1.1917e-01]])                  # For readability, approximate the above as follows              attn_scores_softmax = [              [0.0, 0.5, 0.5],              [0.0, 1.0, 0.0],              [0.0, 0.9, 0.1]              ]              attn_scores_softmax = torch.tensor(attn_scores_softmax)           
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]                  # tensor([[[0.0000, 0.0000, 0.0000],              #          [0.0000, 0.0000, 0.0000],              #          [0.0000, 0.0000, 0.0000]],              #               #         [[1.0000, 4.0000, 0.0000],              #          [2.0000, 8.0000, 0.0000],              #          [1.8000, 7.2000, 0.0000]],              #               #         [[1.0000, 3.0000, 1.5000],              #          [0.0000, 0.0000, 0.0000],              #          [0.2000, 0.6000, 0.3000]]])           

步驟7:求和權重值

outputs = weighted_values.sum(dim=0)                  # tensor([[2.0000, 7.0000, 1.5000],  # Output 1              #         [2.0000, 8.0000, 0.0000],  # Output 2              #         [2.0000, 7.8000, 0.3000]]) # Output 3           

擴充到Transformer

那麼,接下來怎麼辦呢?Transformer!

的确,我們生活在一個深度學習研究和高計算資源的激動人心的時代。Transformer是Attention is All You Need裡面提出的,最初用于執行神經機器翻譯。研究人員在此基礎上進行了重組、切割、添加和擴充,并将其應用到更多的語言任務中。

在這裡,我将簡要地介紹如何将self-attention擴充到Transformer架構。

在self-attention子產品中:

  • Dimension
  • Bias

self-attention子產品的輸入:

  • Embedding module
  • Positional encoding
  • Truncating
  • Masking

增加更多的self-attention子產品:

  • Multihead
  • Layer stacking
  • self-attention子產品之間的子產品:
  • Linear transformations
  • LayerNorm

這就是所有了!希望你覺得内容簡單易懂。

參考文獻:

Attention Is All You Need 

https://arxiv.org/abs/1706.03762

上一篇: 測試概述
下一篇: assert的使用

繼續閱讀