天天看點

PyTorch快餐教程2019 (2) - Multi-Head AttentionPyTorch快餐教程2019 (2) - Multi-Head Attention

PyTorch快餐教程2019 (2) - Multi-Head Attention

上一節我們為了讓一個完整的語言模型跑起來,可能給大家帶來的學習負擔過重了。沒關系,我們這一節開始來還上節沒講清楚的債。

還記得我們上節提到的兩個Attention嗎?

上節我們給大家一個印象,現在我們正式開始介紹其原理。

Scaled Dot-Product Attention

首先說Scaled Dot-Product Attention,其計算公式為:

$

Attention(Q,K,V)=softmax(frac{QK^T}{sqrt{d_k}})V

Q乘以K的轉置,再除以$d_k$的平方根進行縮放,經過一個可選的Mask,經過softmax之後,再與V相乘。

用代碼實作如下:

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn           

Multi-Head Attention

有了縮放點積注意力機制之後,我們就可以來定義多頭注意力。

MultiHead(Q,K,V)=concat(head_1,...,head_n)W^O

其中,$head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)$

這個Attention是我們上面介紹的Scaled Dot-Product Attention.

這些W都是要訓練的參數矩陣。

W_i^Qin mathbb{R}^{d_{model} times d_k},

W_i^Kinmathbb{R}^{d_{model} times d_k}, W_i^Vinmathbb{R}^{d_{model} times d_v}, W_oinmathbb{R}^{hd_v times d_{model}}

h是multi-head中的head數。在《Attention is all you need》論文中,h取值為8。

$d_k=d_v=d_{model}/h=64$

這樣我們需要的參數就是d_model和h.

大家看公式有點要暈的節奏,别怕,我們上代碼:

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "初始化時指定頭數h和模型次元d_model"
        super(MultiHeadedAttention, self).__init__()
        # 二者是一定整除的
        assert d_model % h == 0
        # 按照文中的簡化,我們讓d_v與d_k相等
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)           

其中,clones是複制幾個一模一樣的模型的函數,其定義如下:

def clones(module, N):
    "生成n個相同的層"
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])           

Attention的邏輯主要分為4步。第一步是計算一下mask。

def forward(self, query, key, value, mask=None):
        "實作多頭注意力模型"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)           

第二步是将這一批次的資料進行變形 d_model => h x d_k

query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]           

第三步,針對所有變量計算scaled dot product attention

x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)           

最後,将attention計算結果串聯在一起,其實對張量進行一次變形:

x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)           

再看一種寫法鞏固一下

上面引用的代碼來自:

http://nlp.seas.harvard.edu/2018/04/03/attention.html

為了加深印象,我們再看另一種寫法。

這個的命名更偏工程,d_model叫做hid_dim,h叫做n_heads,但是意思是一回事。

class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        
        # d_model // h 仍然是要能整除,換個名字仍然意義不變
        assert hid_dim % n_heads == 0

        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)

        self.fc = nn.Linear(hid_dim, hid_dim)

        self.do = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)           

下面是處理資料的過程:

def forward(self, query, key, value, mask=None):

# Q,K,V計算與變形:

        bsz = query.shape[0]

        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)

        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)

# Q, K相乘除以scale,這是計算scaled dot product attention的第一步

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

# 如果沒有mask,就生成一個

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

# 然後對Q,K相乘的結果計算softmax加上dropout,這是計算scaled dot product attention的第二步:

        attention = self.do(torch.softmax(energy, dim=-1))

# 第三步,attention結果與V相乘

        x = torch.matmul(attention, V)

# 最後将多頭排列好,就是multi-head attention的結果了

        x = x.permute(0, 2, 1, 3).contiguous()

        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))

        x = self.fc(x)

        return x           

第二種實作取自:

https://github.com/bentrevett/pytorch-seq2seq

繼續閱讀