
Attention is All You Need: implementing and annotating the Transformer and Attention

Reference: Implementing Attention and the Transformer step by step, following the paper "Attention is All You Need"


These notes annotate some details of the code provided in the blog post above.

The example task is machine translation, so readers without a background in this area (especially those coming from computer vision) may find many details hard to follow. I spent some time going through how torchtext is used and how the machine-translation pipeline works, and added comments to the code.

On using torchtext: see reference 1, reference 2, the torchtext documentation, and so on.

The code is split into two parts: the NMT (data loading and training) part, and the model itself.

import numpy as np
import torch
import torch.nn as nn
import time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
#%matplotlib inline
from torchtext import data, datasets
from model import *



#Used to mask the data: produces the source mask and the target mask
class Batch:
    """ 在训练期间使用mask处理数据 """

    def __init__(self, src, trg=None, pad=0):
        #src.size = batch_size, q_len
        self.src = src
        #src_mask.size = batch_size, 1, q_len
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """ 创造一个mask来屏蔽补全词和字典外的词进行屏蔽"""
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask
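
# A toy example I added (not in the original post) showing what Batch produces: the decoder input
# is trg with the last token dropped, and trg_y is the label sequence shifted left by one.
# Call manually to inspect the shapes.
def _demo_batch():
    src = torch.LongTensor([[2, 5, 7, 1, 0, 0]])   # one padded source sentence (pad index 0)
    trg = torch.LongTensor([[2, 4, 6, 8, 3, 0]])   # one padded target sentence
    b = Batch(src, trg, pad=0)
    print(b.src_mask.size())                # torch.Size([1, 1, 6])
    print(b.trg.size(), b.trg_y.size())     # both torch.Size([1, 5])
    print(b.trg_mask.size())                # torch.Size([1, 5, 5]): padding mask & subsequent mask
    print(b.ntokens)                        # 4 non-padding target tokens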

#Wrap the optimizer in another layer for convenience (Noam learning-rate schedule)
class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """ 更新参数和学习率 """
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """ lrate 实现"""
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
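
# A small self-check I added (not part of the original code): the schedule implemented above is
# lrate = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), i.e. linear warmup
# followed by inverse-square-root decay. Call manually to print a few sample rates.
def _demo_noam_rate():
    sched = NoamOpt(model_size=512, factor=1, warmup=4000, optimizer=None)  # optimizer is unused by rate()
    for step in [1, 100, 4000, 20000]:
        print(step, sched.rate(step))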


#Not used here; simply returns a NoamOpt wrapping Adam with the standard settings
def get_std_up(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.parameters(),
                                    lr=0, betas=(0.9, 0.98), eps=1e-9))

'''size is the number of target classes; smoothing is set to 0.1 here'''
#Label smoothing: give the non-target classes a small probability mass as well
class LabelSmoothing(nn.Module):
    """ 标签平滑实现 """
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        #written with the reduction argument to avoid the deprecation warning; the per-element losses are summed later in the loss compute
        self.criterion = nn.KLDivLoss(reduction='none')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    #x is the generator output [n, vocab_size], i.e. the model prediction; target is the ground truth of size n
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        #why subtract 2? padding_idx and the correct label itself are excluded from the smoothing mass
        #true_dist.size() = [n, vocab_size]
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0) #dim,index,val

        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))
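
# A toy check I added (not in the original post): with size=5, padding_idx=0 and smoothing=0.4,
# a row whose label is 2 becomes [0, 0.1333, 0.6, 0.1333, 0.1333], and rows whose label equals the
# padding index are zeroed out entirely. KLDivLoss(reduction='none') returns per-element losses.
def _demo_label_smoothing():
    crit = LabelSmoothing(size=5, padding_idx=0, smoothing=0.4)
    x = torch.log_softmax(torch.randn(4, 5), dim=-1)   # fake model output (log-probabilities)
    target = torch.LongTensor([2, 1, 0, 0])            # the last two positions are padding
    loss = crit(x, target)
    print(crit.true_dist)                              # the smoothed target distribution
    print(loss.sum())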

class MultiGPULossCompute:
    "A multi-gpu loss compute and train function."

    def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
        # Send out to different gpus.
        self.generator = generator
        self.criterion = nn.parallel.replicate(criterion,
                                               devices=devices)
        self.opt = opt
        self.devices = devices
        self.chunk_size = chunk_size

    #size(out) = batch_size, max_len, d_model
    def __call__(self, out, targets, normalize):
        total = 0.0
        generator = nn.parallel.replicate(self.generator,
                                          devices=self.devices)
        out_scatter = nn.parallel.scatter(out,
                                          target_gpus=self.devices)
        out_grad = [[] for _ in out_scatter]
        targets = nn.parallel.scatter(targets,
                                      target_gpus=self.devices)
        # Divide generating into chunks.
        chunk_size = self.chunk_size
        for i in range(0, out_scatter[0].size(1), chunk_size):
            # Predict distributions
            out_column = [[Variable(o[:, i:i + chunk_size].data,
                              requires_grad=self.opt is not None)]
                               for o in out_scatter]

            gen = nn.parallel.parallel_apply(generator[:len(out_column)], out_column, )

            # Compute loss.
            y = [(g.contiguous().view(-1, g.size(-1)),
                  t[:, i:i + chunk_size].contiguous().view(-1))
                 for g, t in zip(gen, targets)]

            loss = nn.parallel.parallel_apply(self.criterion[:len(y)], y)

            # Sum and normalize loss
            l = nn.parallel.gather(loss,
                                   target_device=self.devices[0],dim=0)
            l = l.sum() / normalize
            total += l.data.item()

            #The chunks were created as new out_column tensors, so gradients stop here; collect them manually and back-propagate from out afterwards
            # Backprop loss to output of transformer
            if self.opt is not None:
                l.backward()  # gradients accumulate across chunks
                # sh = out_column[0].detach().cpu()
                # input('here im  !!!!!!!!!!!')
                for j, l in enumerate(loss):  # each out_column[j] holds a single tensor, hence index 0
                    out_grad[j].append(out_column[j][0].grad.data.clone())

        # Backprop all loss through transformer.
        if self.opt is not None:
            #concatenate the pieces that live on the same GPU (each one corresponds to a different chunk)
            out_grad = [Variable(torch.cat(og, dim=1)) for og in out_grad]
            o1 = out
            o2 = nn.parallel.gather(out_grad,target_device=self.devices[0])
            o1.backward(gradient=o2)
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return total * normalize

#During training this plays the same role as BucketIterator; at test time batches follow the normal order
#pool pre-fetches 100 batches worth of examples and sorts them
#every batch is sorted by length
class MyIterator(data.Iterator):
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)

                    for b in random_shuffler(list(p_batch)):  #about 100 times
                        #b is one batch: a list whose elements are Example objects, i.e. training samples
                        yield b

            self.batches = pool(self.data(), self.random_shuffler)

        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

#src_mask masks out the padding positions
#greedy decoding
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)

    #decode step by step, conditioned on src and the tokens generated so far (ys)
    for i in range(max_len - 1):
        out = model.decode(memory, src_mask,
                           Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        #use the generator (a linear layer + log-softmax) to turn the decoder output into word probabilities
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
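
# A quick sanity check I added (not in the original post), in the spirit of the copy-task demo in the
# annotated-transformer tutorial: run greedy_decode with a small untrained model just to confirm the
# shapes flow; the decoded tokens are of course meaningless without training.
def _demo_greedy_decode():
    model = make_model(11, 11, N=2)   # make_model comes from model.py via `from model import *`
    model.eval()
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, 10)
    with torch.no_grad():
        print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
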
global max_src_in_batch, max_tgt_in_batch


#This function enables dynamic batching: the effective batch size is derived from the iterator's configured
#batch size and the lengths of the examples already added.
#The configured batch_size therefore acts as an upper bound on the space one batch may occupy.
#(new example to add, current count of examples in the batch, and current effective batch size)
#returns the new effective batch size resulting from adding that example to a batch
def batch_size_fn(new, count, sofar):
    """ 保持数据批量增加,并计算tokens+padding的总数 """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    #new.src is the source sentence of the current training example
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2) # +2 for the <s> and </s> tokens on the target

    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

def run_epoch(data_iter, pad_idx, model, loss_compute):
    """ Standard training and logging function """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        batch = Batch(batch.src,batch.trg,pad_idx)
        out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        #size(out) = [batch_size, q_len, d_model]
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss : %f Tokens per Sec: %f " % (i, loss/ batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens

def main():

    '''Load the dataset'''
    if True:
        #spaCy is used for tokenization
        import spacy
        spacy_de = spacy.load('de')
        spacy_en = spacy.load('en')
        def tokenize_de(text):
            return [tok.text for tok in spacy_de.tokenizer(text)]
        def tokenize_en(text):
            return [tok.text for tok in spacy_en.tokenizer(text)]

        #start-of-sentence, end-of-sentence and padding tokens
        BOS_WORD = '<s>'
        EOS_WORD = '</s>'
        BLANK_WORD = "<blank>"
        #batch_first=True saves a transpose later on; the original notebook did not set it
        #<s> and </s> go on the target side: the decoder needs them as the start symbol and the stop condition
        SRC = data.Field(tokenize=tokenize_en, pad_token=BLANK_WORD, batch_first=True)
        TGT = data.Field(tokenize=tokenize_de, init_token=BOS_WORD, eos_token=EOS_WORD, pad_token=BLANK_WORD, batch_first=True)
        #drop sentence pairs longer than MAX_LEN
        MAX_LEN = 220
        #build the three datasets
        train, val, test = datasets.IWSLT.splits(
            exts=('.en', '.de'), fields=(SRC, TGT),
            filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN)
        #drop words that appear fewer than MIN_FREQ times
        MIN_FREQ = 2
        SRC.build_vocab(train.src, min_freq=MIN_FREQ)
        TGT.build_vocab(train.trg, min_freq=MIN_FREQ)

    # GPUs to use
    device_ids = [0,1] # with a single GPU, use device_ids=[0]
    device = torch.device(0)
    '''Build the model'''
    if True:
        #pad_idx: index of <blank> in the target vocabulary
        pad_idx = TGT.vocab.stoi["<blank>"]
        #the model is defined in model.py; N is the number of encoder and decoder layers
        model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
        model.cuda()
        #label smoothing (implemented above)
        criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
        criterion.cuda()
        #this batch_size is not a number of sentences but a token-count budget; see batch_size_fn
        BATCH_SIZE = 1200
        #custom Iterator; repeat=False means the iterator stops after one epoch instead of repeating indefinitely
        train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device, repeat=False,
                                sort_key=lambda x : (len(x.src), len(x.trg)),
                               batch_size_fn=batch_size_fn, train=True)
        valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device= device, repeat=False,
                               sort_key=lambda x : (len(x.src), len(x.trg)),
                               batch_size_fn=batch_size_fn, train=False)
        model_par = nn.DataParallel(model, device_ids=device_ids)

    # Training needs a lot of memory, so an out-of-memory error is not unusual; either load the trained model below instead,
    # or reduce BATCH_SIZE
    '''Start training'''
    if True:
        #wrap the optimizer (Noam schedule)
        model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        for epoch in range(10):
            model_par.train()
            run_epoch(train_iter, pad_idx, model_par,
                      MultiGPULossCompute(model.generator, criterion, devices=device_ids, opt=model_opt))
            model_par.eval()
            loss = run_epoch(valid_iter,pad_idx, model_par,
                             MultiGPULossCompute(model.generator, criterion, devices=device_ids, opt=None))
            print("loss is: %f" % loss)
    else:  # load a previously saved model
        model = torch.load("iwslt.pt")

    for i, batch in enumerate(valid_iter):
        #take the first sentence of the batch but keep the batch dimension: size (1, seq_len)
        src = batch.src[:1]
        #src_mask.size() = (1, 1, seq_len)
        src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
        print('src_mask_size=',src_mask.size())

        #out contains vocabulary indices; they still have to be converted into target-language words
        out = greedy_decode(model, src, src_mask, max_len=60, start_symbol=TGT.vocab.stoi["<s>"])
        print("Translation: ", end="\t")
        #print the model's translation
        for i in range(1, out.size(1)):
            sym = TGT.vocab.itos[out[0, i]]
            if sym == "</s>":
                print('meet end------------------')
                break
            print(sym, end=" ")
        print()
        #print the ground-truth target
        print("Target:", end="\t")
        for i in range(1, batch.trg.size(1)):
            sym = TGT.vocab.itos[batch.trg.data[0, i]]
            if sym == '</s>':
                break
            print(sym, end=" ")
        print()
        break

if __name__ == '__main__':
     main()
           
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import  math, copy

#Wraps the whole encoder and decoder
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many other models.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """ Take in and process masked src and target sequences. """
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

#A classification layer that maps d_model to a probability for every word in the vocabulary.
#Because KLDivLoss is used, it outputs log_softmax; with CrossEntropyLoss it could output the raw logits instead.
class Generator(nn.Module):
    """Define standard linear + softmax generation step."""

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

#Used to make N copies of a module
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

#Stacks N encoder layers
class Encoder(nn.Module):
    "Core encoder is a stack of N layers"

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

#Layer normalization (PyTorch also ships nn.LayerNorm)
class LayerNorm(nn.Module):
    """ Construct a layernorm model (See citation for details)"""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

#layernorm + sublayer + residual
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm. Note for
    code simplicity the norm is first as opposed to last .
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply residual connection to any sublayer with the sanme size. """
        return x + self.dropout(sublayer(self.norm(x)))

#Encoder layer: self-attention + position-wise feed-forward
class EncoderLayer(nn.Module):
    """Encoder is made up of self-attention and feed forward (defined below)"""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Follow Figure 1 (left) for connection """
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

#The decoder stacks the decoder layers; memory is the encoder output
#Right-hand side of Figure 1
class Decoder(nn.Module):
    """Generic N layer decoder with masking """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


#Two attention sublayers (each with layer norm + residual connection) + position-wise feed-forward

class DecoderLayer(nn.Module):
    """Decoder is made of self-attn, src-attn, and feed forward (defined below)"""

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Follow Figure 1 (right) for connections"""
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

'''
Example mask, shape = (1, 5, 5):
[[[1 0 0 0 0],
  [1 1 0 0 0],
  [1 1 1 0 0],
  [1 1 1 1 0],
  [1 1 1 1 1]]]
'''
def subsequent_mask(size):
    """Mask out subsequent positions. """
    attn_shape = (1, size, size)
    #k=1: the diagonal of the upper-triangular matrix is 0, so after == 0 each position may attend to itself
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0
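
# Quick check I added (not in the original post): row i of subsequent_mask(n) allows attention only
# to positions <= i. Call manually to print the mask.
def _demo_subsequent_mask():
    print(subsequent_mask(5))   # shape (1, 5, 5), lower-triangular matrix of True/False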

#Used inside multi-head attention
#Scaled Dot-Product Attention
def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention ' """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # matmul: batched matrix multiplication
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    p_attn = F.softmax(scores, dim = -1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
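
# A shape check I added (not in the original post) for Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# with hypothetical sizes: batch 2, 4 heads, 8 query positions, 10 key/value positions, d_k = 64.
def _demo_attention():
    q = torch.randn(2, 4, 8, 64)
    k = torch.randn(2, 4, 10, 64)
    v = torch.randn(2, 4, 10, 64)
    out, p_attn = attention(q, k, v)
    print(out.size())              # torch.Size([2, 4, 8, 64])
    print(p_attn.sum(-1)[0, 0])    # each row of the attention weights sums to 1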

#d_model must be divisible by the number of heads h
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        """ Take in model size and numbe of heads """
        super(MultiHeadedAttention, self).__init__()7
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """图片ModalNet-20的实现"""
        if mask is not None:
            # the same mask is applied to every head
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in a batch: d_model => h x d_k
        # query, key and value each go through one linear layer and are then split into h heads
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention to all the projected vectors in a batch
        # self.attn (the attention weights) is stored but not used further
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads using a view and apply the final linear layer
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
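
# A minimal shape test I added (not in the original post): multi-head self-attention maps
# (batch, seq_len, d_model) to the same shape.
def _demo_multi_head():
    mha = MultiHeadedAttention(h=8, d_model=512)
    x = torch.randn(2, 7, 512)   # hypothetical batch of 2 sequences of length 7
    print(mha(x, x, x).size())   # torch.Size([2, 7, 512])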


#Besides the attention sublayers, every layer of the encoder and decoder contains a fully connected feed-forward
#network, applied to each position (each word) separately and identically: two linear transformations with a ReLU
#in between, which is equivalent to two 1x1 convolutions where each position's features form the channels.
class PositionwiseFeedForward(nn.Module):
    """
    FFN实现
    d_model = 512
    d_ff = 2048
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
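
# FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, applied identically and independently at every position.
# A tiny shape check I added (not in the original post):
def _demo_ffn():
    ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
    x = torch.randn(2, 7, 512)
    print(ffn(x).size())   # torch.Size([2, 7, 512]); only the last dimension is transformed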


class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model) #look up matrix
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    """PE函数实现"""
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        pe.requires_grad = False
        self.register_buffer('pe', pe)

    def forward(self, x):
        #print(x.type(),self.pe.type())
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
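
# The table built above is PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A small sketch I added (not in the original post):
# with dropout = 0 the module simply adds this fixed table to its input.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=20, dropout=0.0)
    y = pe(torch.zeros(1, 100, 20))   # zeros in, so the output equals the raw encoding
    print(y.size())                   # torch.Size([1, 100, 20])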


def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """ Build the full model """
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # !!! important for this to work
    # initialize the parameters with Glorot / fan_avg (Xavier uniform)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
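
# A smoke test I added (not in the original post): build a small model with toy vocabularies of size 11
# and a 2-layer stack, and count its parameters.
def _demo_make_model():
    tmp_model = make_model(11, 11, N=2)
    n_params = sum(p.numel() for p in tmp_model.parameters())
    print(type(tmp_model).__name__, n_params)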
           
