天天看點

Transformer-XL解讀(論文 + PyTorch源碼)

前言

目前在NLP領域中,處理語言模組化問題有兩種最先進的架構:RNN和Transformer。RNN按照序列順序逐個學習輸入的單詞或字元之間的關系,而Transformer則接收一整段序列,然後使用self-attention機制來學習它們之間的依賴關系。這兩種架構目前來看都取得了令人矚目的成就,但它們都局限在捕捉長期依賴性上。

為了解決這一問題,CMU聯合Google Brain在2019年1月推出的一篇新論文《Transformer-XL:Attentive Language Models beyond a Fixed-Length Context》同時結合了RNN序列模組化和Transformer自注意力機制的優點,在輸入資料的每個段上使用Transformer的注意力子產品,并使用循環機制來學習連續段之間的依賴關系。Transformer-XL在多種語言模組化資料集(如單詞級别的enwik8和字元級别的text8)上實作了目前的SoTA效果,且該模型在推理階段速度更快,比之前最先進的利用Transformer進行語言模組化的方法快300~1800倍。 同時,該論文也放出了其配套源碼(包括TensorFlow和PyTorch的)、預訓練模型及在各個資料集上訓練的超參數,可以說是非常良心了~造福我等伸手黨!

本文将主要針對模型原理及其PyTorch實作進行逐一對照解讀,因筆者能力有限,如有不詳盡之處,可移步文末的傳送門進行詳細閱讀,并歡迎指出~

文章目錄

    • 前言
    • 一. 回顧Transformer
    • 二. vanilla Transformer
    • 三. Transformer-XL
      • 1. 引入循環機制
      • 2. 相對位置編碼
      • 3. 整體計算公式
    • 四. PyTorch實作
    • 五. 實驗結果
      • 1. 語言模組化名額
      • 2. 兩個創新點的優勢
      • 3. 測試階段的速度
    • 六. 總結
      • 1. 模型特點
      • 2. 優點
      • 3. 不足
    • 傳送門

一. 回顧Transformer

在NLP領域中,一種對語言模組化的最常用模型就是RNN,它可以捕捉單詞之間的依賴關系。但因為梯度消失和爆炸的問題,RNN變得非常難以訓練,LSTM單元和梯度裁剪方法的提出也不足以解決此類問題。同時RNN網絡的計算速度往往很慢,其學習長期依賴的能力也較為有限(論文中提到,LSTM語言模型平均隻能模組化200個上下文詞語)。

2017年6月,Google Brain在論文《Attention Is All You Need》中提出的Transformer架構,完全摒棄了RNN的循環機制,采用一種self-attention的方式進行全局處理。其接收一整段序列,并使用三個可訓練的權重矩陣——Query、Key和Value來一次性學習輸入序列中各個部分之間的依賴關系。Transformer網絡由多個層組成,每個層都由多頭注意力機制和前饋網絡構成。由于在全局進行注意力機制的計算,忽略了序列中最重要的位置資訊。Transformer為輸入添加了位置編碼(Positional Encoding),使用正弦函數完成,為每個部分的位置生成位置向量,不需要學習,用于幫助網絡學習其位置資訊。其示意如下圖所示:

Transformer-XL解讀(論文 + PyTorch源碼)

有關Transformer的更深入讨論,可參考筆者之前的部落格:

Transformer(論文 + PyTorch源碼解讀)

二. vanilla Transformer

為何要提這個模型?因為Transformer-XL是基于這個模型進行的改進。

Al-Rfou等人基于Transformer提出了一種訓練語言模型的方法( https://arxiv.org/abs/1808.04444 ),來根據之前的字元預測片段中的下一個字元。例如,它使用 x 1 , x 2 , . . . , x n − 1 x_1, x_2, ..., x_{n-1} x1​,x2​,...,xn−1​預測字元 x n x_n xn​,而在 x n x_n xn​之後的序列則被mask掉。論文中使用64層模型,并僅限于處理 512個字元這種相對較短的輸入,是以它将輸入分成段,并分别從每個段中進行學習,如下圖所示。 在測試階段如需處理較長的輸入,該模型會在每一步中将輸入向右移動一個字元,以此實作對單個字元的預測。

Transformer-XL解讀(論文 + PyTorch源碼)

該模型在常用的資料集如enwik8和text8上的表現比RNN模型要好,但它仍有以下兩個缺點:

a. 上下文長度受限:字元之間的最大依賴距離受輸入長度的限制,模型看不到出現在幾個句子之前的單詞。

b. 上下文碎片:對于長度超過512個字元的文本,都是從頭開始單獨訓練的。段與段之間沒有上下文依賴性,會讓訓練效率低下,也會影響模型的性能。

c. 推理速度慢:在測試階段,每次預測下一個單詞,都需要重新建構一遍上下文,并從頭開始計算,這樣的計算速度非常慢。

三. Transformer-XL

Transformer-XL架構在vanilla Transformer的基礎上引入了兩點創新:循環機制(Recurrence Mechanism)和相對位置編碼(Relative Positional Encoding),以克服vanilla Transformer的缺點。與vanilla Transformer相比,Transformer-XL的另一個優勢是它可以被用于單詞級和字元級的語言模組化。

1. 引入循環機制

與vanilla Transformer的基本思路一樣,Transformer-XL仍然是使用分段的方式進行模組化,但其與vanilla Transformer的本質不同是在于引入了段與段之間的循環機制,使得目前段在模組化的時候能夠利用之前段的資訊來實作長期依賴性。如下圖所示:

Transformer-XL解讀(論文 + PyTorch源碼)

在訓練階段,處理後面的段時,每個隐藏層都會接收兩個輸入:

  1. 該段的前面隐藏層的輸出,與vanilla Transformer相同(上圖的灰色線)。
  2. 前面段的隐藏層的輸出(上圖的綠色線),可以使模型建立長期依賴關系。

這兩個輸入會被拼接,然後用于計算目前段的Key和Value矩陣。對于某個段的某一層的具體計算公式如下:

Transformer-XL解讀(論文 + PyTorch源碼)

其中, τ \tau τ表示第幾段, n n n表示第幾層, h h h表示隐層的輸出。 S G ( ⋅ ) SG(·) SG(⋅)表示停止計算梯度, [ h u ∘ h v ] [h_u \circ h_v] [hu​∘hv​]表示在長度次元上的兩個隐層的拼接, W . W_. W.​是模型參數。乍一看與Transformer中的計算公式很像,唯一關鍵的不同就在于Key和Value矩陣的計算上,即 k τ + 1 n k_{\tau+1}^n kτ+1n​和 v τ + 1 n v_{\tau + 1}^n vτ+1n​,它們基于的是擴充後的上下文隐層狀态 h ~ τ + 1 n − 1 \tilde{h}_{\tau+1}^{n-1} h~τ+1n−1​進行計算, h τ n − 1 {h}_{\tau}^{n-1} hτn−1​是之前段的緩存。

原則上隻要GPU記憶體允許,該方法可以利用前面更多段的資訊,測試階段也可以獲得更長的依賴。

在測試階段,與vanilla Transformer相比,其速度也會更快。在vanilla Transformer中,一次隻能前進一個step,并且需要重新建構段,并全部從頭開始計算;而在Transformer-XL中,每次可以前進一整個段,并利用之前段的資料來預測目前段的輸出。

2. 相對位置編碼

在Transformer中,一個重要的地方在于其考慮了序列的位置資訊。在分段的情況下,如果僅僅對于每個段仍直接使用Transformer中的位置編碼,即每個不同段在同一個位置上的表示使用相同的位置編碼,就會出現問題。比如,第 i − 2 i-2 i−2段和第 i − 1 i-1 i−1段的第一個位置将具有相同的位置編碼,但它們對于第 i i i段的模組化重要性顯然并不相同(例如第 i − 2 i-2 i−2段中的第一個位置重要性可能要低一些)。是以,需要對這種位置進行區分。

論文對于這個問題,提出了一種新的位置編碼的方式,即會根據詞之間的相對距離而非像Transformer中的絕對位置進行編碼。在Transformer中,第一層的計算查詢 q i T q_i^T qiT​和鍵 k j k_j kj​之間的attention分數的方式為:

Transformer-XL解讀(論文 + PyTorch源碼)

其中, E x i E_{x_i} Exi​​是詞 i i i的embedding, E x j E_{x_j} Exj​​是詞 j j j的embedding, U i U_i Ui​和 U j U_j Uj​是位置向量,這個式子實際上是 ( W q ( E x i + U i ) ) T ⋅ ( W k ( E x j + U j ) ) (W_q(E_{x_i}+U_i))^T·(W_k(E_{x_j}+U_j)) (Wq​(Exi​​+Ui​))T⋅(Wk​(Exj​​+Uj​))的展開,就是Transformer中的标準格式。

在Transformer-XL中,對上述的attention計算方式進行了變換,轉為相對位置的計算,而且不僅僅在第一層這麼計算,在每一層都是這樣計算。

Transformer-XL解讀(論文 + PyTorch源碼)

對比來看,主要有三點變化:

  1. 在(b)和(d)這兩項中,将所有絕對位置向量 U j U_j Uj​都轉為相對位置向量 R i − j R_{i-j} Ri−j​,與Transformer一樣,這是一個固定的編碼向量,不需要學習。
  2. 在(c)這一項中,将查詢的 U i T W q T U_i^TW_q^T UiT​WqT​向量轉為一個需要學習的參數向量 u u u,因為在考慮相對位置的時候,不需要查詢的絕對位置 i i i,是以對于任意的 i i i,都可以采用同樣的向量。同理,在(d)這一項中,也将查詢的 U i T W q T U_i^TW_q^T UiT​WqT​向量轉為另一個需要學習的參數向量 v v v。
  3. 将鍵的權重變換矩陣 W k W_k Wk​轉為 W k , E W_{k, E} Wk,E​和 W k , R W_{k, R} Wk,R​,分别作為content-based key vectors和location-based key vectors。

從另一個角度來解讀這個公式的話,可以将attention的計算分為如下四個部分:

a. 基于内容的“尋址”,即沒有添加原始位置編碼的原始分數。

b. 基于内容的位置偏置,即相對于目前内容的位置偏差。

c. 全局的内容偏置,用于衡量key的重要性。

d. 全局的位置偏置,根據query和key之間的距離調整重要性。

3. 整體計算公式

結合上面兩個創新點,将Transformer-XL模型的整體計算公式整理如下,這裡考慮一個N層的隻有一個注意力頭的模型:

Transformer-XL解讀(論文 + PyTorch源碼)

其中, τ \tau τ代表第幾段, n n n代表第幾層, h τ 0 : = E s τ h_\tau^0 := E_{s_\tau} hτ0​:=Esτ​​定義為第 τ \tau τ段的詞向量序列。值得一提的是,計算 A A A矩陣的時候,需要對所有的 i − j i-j i−j計算 W k , R n R i − j W_{k,R}^nR_{i-j} Wk,Rn​Ri−j​,如果直接按照公式計算的話,計算時間是 O ( l e n g t h ) 2 O(length)^2 O(length)2,而實際上 i − j i-j i−j的範圍隻從0 ~ length,是以可以先計算好這length個向量,然後在實際計算 A A A矩陣時直接取用即可。

具體的,設 M M M和 L L L分别為memory和目前段序列的長度,則 i − j i-j i−j的範圍也就為0 ~ M + L − 1 M + L - 1 M+L−1。下面的 Q Q Q矩陣中的每一行都代表着 W k , R R i − j W_{k,R}R_{i-j} Wk,R​Ri−j​中一個 i − j i-j i−j的可能性,即 Q k = W k , R R M + L − 1 − k Q_k = W_{k, R} R_{M+L-1-k} Qk​=Wk,R​RM+L−1−k​。

Transformer-XL解讀(論文 + PyTorch源碼)

則對于上面公式中的(b)項,即 q i T W k , R R i − j q_i^TW_{k,R}R_{i-j} qiT​Wk,R​Ri−j​,其構成的所有可能向量的矩陣為 B B B矩陣,其形狀為 L ∗ ( M + L ) L * (M + L) L∗(M+L),這是我們最終需要的(b)項的attention結果。

Transformer-XL解讀(論文 + PyTorch源碼)

我們進一步定義 B ~ \tilde{B} B~矩陣為如下:

Transformer-XL解讀(論文 + PyTorch源碼)

可見,需要的 B B B矩陣的每一行隻是 B ~ \tilde{B} B~的向左shift而已。是以,可以直接利用矩陣乘法計算 B ~ \tilde{B} B~即可。設 R i − j R_{i-j} Ri−j​的次元為 d R d_R dR​, q i q_i qi​的次元為 d q d_q dq​, W k , R W_{k,R} Wk,R​矩陣的次元為 d q ∗ d R d_q * d_R dq​∗dR​,則直接計算矩陣B的時間複雜度為 2 ∗ d q ∗ d R ∗ L ∗ ( M + L ) 2* d_q * d_R * L * (M+L) 2∗dq​∗dR​∗L∗(M+L),而計算 B ~ \tilde{B} B~的時間複雜度為 L ∗ d q ∗ ( M + L ) + d q ∗ d R ∗ ( M + L ) L * d_q * (M + L) + d_q * d_R * (M + L) L∗dq​∗(M+L)+dq​∗dR​∗(M+L),計算量明顯不是一個量級(後者要快很多)。

同理,對于(d)項來說,可以對所有的 i − j i-j i−j定義需要的矩陣 D D D為 L ∗ ( M + L ) L * (M+L) L∗(M+L):

Transformer-XL解讀(論文 + PyTorch源碼)

可以用如下的 d ~ \tilde{d} d~來進行shift得到:

Transformer-XL解讀(論文 + PyTorch源碼)

其中 Q Q Q矩陣已經計算過了,也可以在這一步減少計算量。

四. PyTorch實作

筆者在這裡主要研究的是核心模型部分,将針對關鍵的實作細節進行剖析,想要看完整代碼的讀者請戳這裡。

  1. 首先來看RelativePositionalEmbedding部分。
class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))

    def forward(self, pos_seq):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        return pos_emb[:,None,:]
           

這裡的

demb

是相對位置編碼的次元,

pos_seq

是序列的位置向量,在代碼裡面是

torch.arange(klen-1, -1, -1.0)

,其中的

klen

mlen+qlen

,從名稱和之前的原理介紹可知這裡的

mlen

是memory的長度,

qlen

是query的長度,這兩者組成了key的長度。最終傳回的即是 R R R向量矩陣,可見是不需要學習的。

  1. 接着來看MultiHeadAttention的部分,為了叙述友善,這裡的MultiHeadAttn是源代碼中的RelMultiHeadAttn和RelPartialLearnableMultiHeadAttn的整合,也即一層self-attention的計算方式。
class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

		self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

		def _rel_shift(self, x, zero_triu=False):
	        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
	                               device=x.device, dtype=x.dtype)
	        x_padded = torch.cat([zero_pad, x], dim=1)
	
	        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
	
	        x = x_padded[1:].view_as(x)
	
	        if zero_triu:
	            ones = torch.ones((x.size(0), x.size(1)))
	            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
	
	        return x

        def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
	        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
	
	        if mems is not None:
	            cat = torch.cat([mems, w], 0)
	            if self.pre_lnorm:
	                w_heads = self.qkv_net(self.layer_norm(cat))
	            else:
	                w_heads = self.qkv_net(cat)
	            r_head_k = self.r_net(r)
	
	            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
	            w_head_q = w_head_q[-qlen:]
	        else:
	            if self.pre_lnorm:
	                w_heads = self.qkv_net(self.layer_norm(w))
	            else:
	                w_heads = self.qkv_net(w)
	            r_head_k = self.r_net(r)
	
	            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
	
	        klen = w_head_k.size(0)
	
	        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
	        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
	        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
	
	        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head
	
	        #### compute attention score
	        rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
	        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
	
	        rr_head_q = w_head_q + r_r_bias
	        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
	        BD = self._rel_shift(BD)
	
	        # [qlen x klen x bsz x n_head]
	        attn_score = AC + BD
	        attn_score.mul_(self.scale)
	
	        #### compute attention probability
	        if attn_mask is not None and attn_mask.any().item():
	            if attn_mask.dim() == 2:
	                attn_score = attn_score.float().masked_fill(
	                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
	            elif attn_mask.dim() == 3:
	                attn_score = attn_score.float().masked_fill(
	                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
	
	        # [qlen x klen x bsz x n_head]
	        attn_prob = F.softmax(attn_score, dim=1)
	        attn_prob = self.dropatt(attn_prob)
	
	        #### compute attention vector
	        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
	
	        # [qlen x bsz x n_head x d_head]
	        attn_vec = attn_vec.contiguous().view(
	            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
	
	        ##### linear projection
	        attn_out = self.o_net(attn_vec)
	        attn_out = self.drop(attn_out)
	
	        if self.pre_lnorm:
	            ##### residual connection
	            output = w + attn_out
	        else:
	            ##### residual connection + layer normalization
	            output = self.layer_norm(w + attn_out)
	
	        return output
           

其中

n_head,d_model,d_head

分别表示注意力頭的個數,模型的隐層次元,每個頭的隐層次元。

qkv_net

是用于計算query、key和value變換的參數矩陣 W q , W k , E , W v W_{q}, W_{k,E}, W_{v} Wq​,Wk,E​,Wv​,與标準的Transformer中一緻,

o_net

是用于将所有注意力頭的結果拼接後再變換到模型次元的參數矩陣,

layer_norm

是LayerNormalization層,

r_net

是用于計算relative position embedding變換的參數矩陣 W k , R W_{k,R} Wk,R​。

在前向計算的過程中,

w

r

分别是上一層的輸出以及RelativePositionEmbedding,

r_w_bias

r_r_bias

分别是 u u u向量和 v v v向量,

AC

是前面公式中的(a)項和(c)項,

BD

是前面公式中的(b)項和(d)項,根據前面講的快速計算帶有相對位置的項,這裡的

BD

需要進行偏移,即

_rel_shift

,經過筆者的演算,發現這裡經過此函數後的BD并不是想要的 B B B矩陣,其在 B B B矩陣的(M+1)對角線(設主對角線為0,正數即為向右上偏移的量)的右上還有元素,不過後面緊接着就進行了mask。這裡的

attn_mask

即為

torch.triu(word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

。再往後就是标準的Transformer中的add&norm環節了,就不再贅述。

  1. 最後來看memory的更新過程:
def _update_mems(self, hids, mems, qlen, mlen):
    # does not deal with None
    if mems is None: return None

    # mems is not None
    assert len(hids) == len(mems), 'len(hids) != len(mems)'

    # There are `mlen + qlen` steps that can be cached into mems
    # For the next step, the last `ext_len` of the `qlen` tokens
    # will be used as the extended context. Hence, we only cache
    # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
    # to `mlen + qlen - self.ext_len`.
    with torch.no_grad():
        new_mems = []
        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
        beg_idx = max(0, end_idx - self.mem_len)
        for i in range(len(hids)):

            cat = torch.cat([mems[i], hids[i]], dim=0)
            new_mems.append(cat[beg_idx:end_idx].detach())

    return new_mems
           

這裡的

hids

是目前段每層的輸出,

mems

為目前段每層依賴的memory,

qlen

為序列長度,

mlen

為目前段依賴的memory的長度。

從代碼來看的話,前面的循環示意圖似乎有些問題?感覺在訓練階段,對于每個段裡面的第二個位置開始的點,都應該連到第一個位置連到的最前面memory?因為用的是同樣長度的memory。

五. 實驗結果

1. 語言模組化名額

在最關心的語言模型模組化名額上,論文比較了模型在單詞級别和字元級别上不同資料集的表現,并且與RNN和(vanilla) Transformer都做了比較。實驗證明,Transformer-XL在各個不同的資料集上均實作了目前的SoTA:在大型單詞級别資料集WikiText-103上,Transformer-XL将困惑度從20.5降到18.3;在enwiki8資料集上,12層Transformer-XL的bpc達到了1.06,相同bpc的AI-Rfou的模型( https://arxiv.org/abs/1808.04444 )參數量卻是6倍,24層Transformer-XL的bpc更是達到了0.99;在One Billion Word資料集上(僅具有短句的)和Penn Treebank資料集上(小型,僅有1M)也取得了SoTA的效果,前者的困惑度從23.7到21.8,後者的困惑度從55.3到54.5。表明了Transformer-XL在各個資料集下的不俗競争力。

2. 兩個創新點的優勢

下圖比較了不同上下文長度(即memory的長度)中包不包含循環機制、以及使不使用新位置編碼方式的困惑度得分。可見,使用循環機制和相對位置編碼的Transformer-XL明顯優于其他的模型,并且能夠有效利用長期依賴性,而且它能捕獲超出RNN 80%的依賴性,和超出Transformer 450%的依賴性。

Transformer-XL解讀(論文 + PyTorch源碼)

3. 測試階段的速度

Transformer-XL的推理速度也明顯快于vanilla Transformer,尤其是對于較長的上下文。比如,在上下文長度為800時,Transformer-XL提速363倍;而當上下文長度增加到3800時,Transformer-XL提速1874倍!

六. 總結

1. 模型特點

在 AI-Rfou 等人提出的vanilla Transformer上做了兩點創新:

  1. 引入循環機制(Recurrence Mechanism)
  2. 相對位置編碼(Relative Positional Encoding)

2. 優點

  1. 在幾種不同的資料集(大/小,字元級别/單詞級别等)均實作了最先進的語言模組化結果。
  2. 結合了深度學習的兩個重要概念——循環機制和注意力機制,允許模型學習長期依賴性,且可能可以擴充到需要該能力的其他深度學習領域,例如音頻分析(如每秒16k樣本的語音資料)等。
  3. 在inference階段非常快,比之前最先進的利用Transformer模型進行語言模組化的方法快300~1800倍。
  4. 有詳盡的源碼!含TensorFlow和PyTorch版本的,并且有TensorFlow預訓練好的模型及各個資料集上詳盡的超參數設定。

3. 不足

  1. 尚未在具體的NLP任務如情感分析、QA等上應用。
  2. 沒有給出與其他的基于Transformer的模型,如BERT等,對比有何優勢。
  3. 在Github源碼中提到,目前的sota結果是在TPU大叢集上訓練得出,對于我等渣機器黨就隻能玩玩base模式了。

傳送門

論文:https://arxiv.org/pdf/1901.02860.pdf

代碼:https://github.com/kimiyoung/transformer-xl

參考:https://www.lyrn.ai/2019/01/16/transformer-xl-sota-language-model

繼續閱讀