天天看點

位置編碼

為什麼需要位置編碼

在transformer中使用了位置編碼,為什麼需要位置編碼。因為對于transformer中的注意力機制而言,交換兩個單詞,并不會影響注意力的計算,也就是說這裡的注意力是對單詞位置不敏感的,而單詞之間的位置資訊往往是很重要的,是以考慮使用位置編碼。

絕對位置編碼

三角函數位置編碼

transformer使用的位置編碼。基本公式:

\[p_{k,2i} = sin(\frac{k}{10000^{\frac{2i}{d}}}) \\

p_{k,2i+1} = cos(\frac{k}{10000^{\frac{2i}{d}}})

\]

\(p_{k}\)表示序列中第k個單詞,2i及2i+1是其的兩個分量,也就是說,第k個位置編碼是由兩部分構成的。假設句子長度為512,那麼位置編碼向量次元就是512×2。那麼為什麼會使用這種位置編碼表示呢?首先三角函數有以下性質:

\[sin(\alpha+\beta) = sin\alpha\cos\beta+cos\alpha\sin\beta \\

cos(\alpha+\beta) = cos\alpha\cos\beta-sin\alpha\sin\beta

\]

那麼:

\[p_{m}=[p_{m,2i},p_{m, 2i}] \\

p_{m}=[sin(\frac{m}{10000^{\frac{2i}{d}}}), cos(\frac{m}{10000^{\frac{2i}{d}}})] \\

p_{m+k}=[p_{m+k,2i},p_{m+k, 2i}] \\

p_{m+k}=[sin(\frac{m+k}{10000^{\frac{2i}{d}}}), cos(\frac{m+k}{10000^{\frac{2i}{d}}})] \\

\]

我們把\(\frac{1}{10000^{\frac{2i}{d}}}\)記為a,則有:

\[p_{m+k} = [sinamcosak+cosamsinak, cosamcosak-sinamsinak] \\

p_{m+k}=

\left[

\begin{matrix}

cosak&sinak\\

-sinak&cosak\\

\end{matrix}

\right]

\left[

\begin{matrix}

sinam \\

cosam \\

\end{matrix}

\right]

= \left[

\begin{matrix}

cosak&sinak\\

-sinak&cosak\\

\end{matrix}

\right]p_{m}

\]

也就是說第m+k個位置的位置編碼可以由第m個位置表示。另有:

\[P_{t+k} = R_{k}P_{t} \\

P_{t+k1+k2} = R_{k1+k2}P_{t}=R_{k1}R_{k2}P_{t} \\

則有:\\

R_{k1+k2} =R_{k1}R_{k2} \\

R_{k1-k2} =R_{k1}R_{-k2} \\

因為:\\

-sin\alpha=sin-\alpha, cos\alpha=cos-\alpha \\

是以:\\

R_{-k2}=(R_{k2})^{T} \\

最終:\\

R_{k1-k2}=R_{k1}(R_{k2})^{T} 或者 R_{k2-k1}=R_{k2}(R_{k1})^{T}

\]

參考實作:

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)

           

可學習的位置編碼

這一種位置編碼就是Bert模型采用的。為什麼bert不用transformer的三角函數編碼,因為bert訓練采用了更大的預料,使用可學習的位置編碼效果可能更好。

遞歸式位置編碼

這裡摘錄蘇劍林的文章:

原則上來說,RNN模型不需要位置編碼,它在結構上就自帶了學習到位置資訊的可能性(因為遞歸就意味着我們可以訓練一個“數數”模型),是以,如果在輸入後面先接一層RNN,然後再接Transformer,那麼理論上就不需要加位置編碼了。同理,我們也可以用RNN模型來學習一種絕對位置編碼,比如從一個向量\(p_{0}\)出發,通過遞歸格式 \(p_{k+1}=f(p_{k})\) 來得到各個位置的編碼向量。

ICML 2020的論文《Learning to Encode Position for Transformer with Continuous Dynamical Model》把這個思想推到了極緻,它提出了用微分方程(ODE)\(dp_{t}/d_{t=h(p_{t},t)}\) 的方式來模組化位置編碼,該方案稱之為FLOATER。顯然,FLOATER也屬于遞歸模型,函數\(h(p_{t},t)\)可以通過神經網絡來模組化,是以這種微分方程也稱為神經微分方程,關于它的工作最近也逐漸多了起來。

理論上來說,基于遞歸模型的位置編碼也具有比較好的外推性,同時它也比三角函數式的位置編碼有更好的靈活性(比如容易證明三角函數式的位置編碼就是FLOATER的某個特解)。但是很明顯,遞歸形式的位置編碼犧牲了一定的并行性,可能會帶速度瓶頸。

相對位置編碼

直接去看蘇劍林的文章:https://zhuanlan.zhihu.com/p/352898810

旋轉位置編碼:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

context_outputs = torch.rand((32, 512, 768))
last_hidden_state = context_outputs  # 這裡的context_outputs是bert的輸出
# # last_hidden_state:[batch_size, seq_len, hidden_size]
batch_size = last_hidden_state.size()[0]
seq_len = last_hidden_state.size()[1]

hidden_size = 768
ent_type_size = 10
inner_dim = 64
# self.ent_type_size表示的是實體的總數, inner_dim自定義為64
# outputs:(batch_size, seq_len, ent_type_size*inner_dim*2)=[32, 512, 10*64*2]
outputs = nn.Linear(hidden_size, ent_type_size * inner_dim * 2)(last_hidden_state)
# 得到10個[32, 512, 64*2]
outputs = torch.split(outputs, inner_dim * 2, dim=-1)
# [32, 512, 10, 64*2]
outputs = torch.stack(outputs, dim=-2)
# qw和kw都是:[32, 512, 10, 64]
qw, kw = outputs[..., :inner_dim], outputs[..., inner_dim:]
"""這下面就是旋轉位置編碼主代碼"""


def sinusoidal_position_embedding(batch_size, seq_len, output_dim):
    """這裡是最初得正餘弦位置編碼"""
    # position_ids:[512, 1]
    position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
    # [32],從0-31
    indices = torch.arange(0, output_dim // 2, dtype=torch.float)
    # 10000^(-[0,...,31]/64)
    indices = torch.pow(10000, -2 * indices / output_dim)
    # [512, 32]
    embeddings = position_ids * indices
    # torch.Size([512, 32, 2])
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
    # [32, 512, 32, 2]
    embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
    # [32, 512, 64]
    embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
    return embeddings


pos_emb = sinusoidal_position_embedding(batch_size,
                                        seq_len,
                                        output_dim=inner_dim)
# 取奇數位,奇數位是cos
# repeat_interleave重複張量得元素
# torch.Size([32, 512, 1, 64])
cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
# torch.Size([32, 512, 1, 64])
# 偶數位是sin
sin_pos = pos_emb[..., None,::2].repeat_interleave(2, dim=-1)

# torch.Size([32, 512, 10, 32, 2])
# 重新排列
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
# [32, 512, 10, 64]
qw2 = qw2.reshape(qw.shape)
# [32, 512, 10, 64] * [32, 512, 1, 64] + [32, 512, 10, 64] * [32, 512, 1, 64]
qw = qw * cos_pos + qw2 * sin_pos  # 這就是旋轉位置編碼得最終結果
kw2 = torch.stack([-kw[..., 1::2], kw[...,::2]], -1)
kw2 = kw2.reshape(kw.shape)
kw = kw * cos_pos + kw2 * sin_pos
           

參考

蘇劍林-讓研究人員絞盡腦汁的Transformer位置編碼

三角函數位置編碼實作

六種位置編碼的代碼實作及性能實驗

繼續閱讀