seq2seq with attention: code implementation
For the theory behind the attention-based seq2seq model, see: seq2seq + attention 詳解 (detailed explanation).
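As a quick reference, the attention used below is the "concat" scoring method from Luong et al., which computes, for an encoder state h and a decoder state s,

    score(h, s) = v^T tanh(W1 h + W2 s)

with learned parameters v, W1 and W2 (self.v, self.fc1 and self.fc2 in the Attn module below); the scores are softmax-normalized over the source positions to give the attention weights used to build each context vector.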
An implementation of a seq2seq model with Luong attention follows:
# coding: utf-8
# author = 'xy'
"""
model2: encoder + attn + decoder
we use Bi-gru as our encoder, gru as decoder, Luong attention(concat method) as our attention
It refers to paper "Effective Approaches to Attention-based Neural Machine Translation"
"""
import numpy as np
import torch
from torch import nn
from torch.nn import functional as f
import test_helper
class Encoder(nn.Module):

    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True
        )

    def forward(self, src, src_len):
        """
        :param src: tensor, cuda, (seq_len, batch_size)
        :param src_len: tensor, (batch_size)
        :return: outputs (seq_len, batch_size, hidden_size*2), h_t (num_layers, batch_size, hidden_size*2)
        """
        src = self.embedding(src)
        src = nn.utils.rnn.pack_padded_sequence(src, src_len)
        outputs, h_t = self.rnn(src, None)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # h_t is (num_layers*2, batch, hidden); concatenate the forward and
        # backward states of each layer into (num_layers, batch, hidden*2)
        h_t = torch.cat((h_t[0::2], h_t[1::2]), dim=2)
        return outputs, h_t
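# A minimal shape check for the Encoder (a sketch; the vocabulary size,
# embedding dimension and lengths below are illustrative, not part of the model):
#
#   emb = nn.Embedding(1000, 64)
#   enc = Encoder(input_size=64, hidden_size=128, embedding=emb)
#   src = torch.randint(1, 1000, (7, 4))    # (seq_len, batch_size)
#   src_len = torch.tensor([7, 6, 5, 3])    # lengths, sorted in decreasing order
#   outputs, h_t = enc(src, src_len)        # (7, 4, 256) and (1, 4, 256)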
class Attn(nn.Module):

    def __init__(self, hidden_size):
        super(Attn, self).__init__()
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(hidden_size*2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1)

    def forward(self, outputs, src_len, ss):
        """
        :param outputs: encoder states h, (src_seq_len, batch_size, hidden_size*2)
        :param src_len: tensor, (batch_size)
        :param ss: decoder states s, (tgt_seq_len, batch_size, hidden_size)
        :return: content tensor, (batch_size, tgt_seq_len, hidden_size*2)
        """
        src_seq_len = outputs.size(0)
        tgt_seq_len = ss.size(0)
        batch_size = outputs.size(1)

        # W1*h for every encoder state, broadcast over the target steps:
        # (batch_size, tgt_seq_len, src_seq_len, hidden_size)
        h = outputs.view(-1, self.hidden_size*2)
        wh = self.fc1(h).view(src_seq_len, batch_size, self.hidden_size)
        wh = wh.transpose(0, 1)
        wh = wh.unsqueeze(1)
        wh = wh.expand(batch_size, tgt_seq_len, src_seq_len, self.hidden_size)

        # W2*s for every decoder state, broadcast over the source steps:
        # (batch_size, tgt_seq_len, src_seq_len, hidden_size)
        s = ss.view(-1, self.hidden_size)
        ws = self.fc2(s).view(tgt_seq_len, batch_size, self.hidden_size)
        ws = ws.transpose(0, 1)
        ws = ws.unsqueeze(2)
        ws = ws.expand(batch_size, tgt_seq_len, src_seq_len, self.hidden_size)

        # scores v^T tanh(W1*h + W2*s): (batch_size, tgt_seq_len, src_seq_len)
        hs = torch.tanh(wh + ws)
        hs = hs.view(-1, self.hidden_size)
        hs = self.v(hs).view(batch_size, tgt_seq_len, src_seq_len)

        # mask out padded source positions before the softmax
        mask = []
        for i in src_len:
            i = i.item()
            mask.append([0]*i + [1]*(src_seq_len-i))
        mask = torch.tensor(mask, dtype=torch.bool)
        mask = mask.unsqueeze(1)
        mask = mask.expand(batch_size, tgt_seq_len, src_seq_len).cuda()
        hs.masked_fill_(mask, -float('inf'))

        # attention weights over the source, then the weighted context vectors
        hs = f.softmax(hs, dim=2)
        hs = torch.bmm(hs, outputs.transpose(0, 1))
        return hs
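# Shape walk-through for Attn (illustrative numbers): with src_seq_len=7,
# tgt_seq_len=5, batch_size=4 and hidden_size=128, wh and ws both expand to
# (4, 5, 7, 128); v(tanh(wh + ws)) collapses the last dim to scores of shape
# (4, 5, 7); padded source positions are set to -inf so the softmax over dim=2
# gives them zero weight; bmm with the encoder outputs (4, 7, 256) then yields
# a context tensor of shape (4, 5, 256).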
class Decoder(nn.Module):

    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Decoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=False
        )
        self.attn = Attn(hidden_size)
        self.fc = nn.Linear(hidden_size*2, hidden_size)
        self.fh = nn.Linear(hidden_size, hidden_size)
        self.ws = nn.Linear(hidden_size, embedding.num_embeddings)

    def forward(self, tgt, state, outputs, src_len, teacher_forcing):
        """
        :param tgt: index, (tgt_seq_len, batch_size)
        :param state: s_t, (num_layers, batch_size, hidden_size)
        :param outputs: h, (src_seq_len, batch_size, hidden_size*2)
        :param src_len: tensor, (batch_size)
        :param teacher_forcing: probability of using the ground-truth tokens as input
        :return: result (tgt_seq_len, batch_size, vocab_size), state (num_layers, batch_size, hidden_size)
        """
        flag = np.random.random() < teacher_forcing
        # teacher_forcing mode, also used in testing
        if flag:
            embedded = self.embedding(tgt)
            ss, state = self.rnn(embedded, state)
            content = self.attn(outputs, src_len, ss).transpose(0, 1)  # (tgt_seq_len, batch_size, hidden_size*2)
            content = content.contiguous().view(-1, self.hidden_size*2)
            ss = ss.view(-1, self.hidden_size)
            result = torch.tanh(self.fc(content) + self.fh(ss))
            result = self.ws(result).view(tgt.size(0), tgt.size(1), -1)
        # generation mode: feed each prediction back in as the next input
        else:
            result = []
            embedded = self.embedding(tgt[:1])  # the first target token, e.g. <sos>
            for i in range(tgt.size(0)):
                ss, state = self.rnn(embedded, state)
                content = self.attn(outputs, src_len, ss).transpose(0, 1)
                content = content.view(-1, self.hidden_size*2)
                ss = ss.view(-1, self.hidden_size)
                r = torch.tanh(self.fc(content) + self.fh(ss))
                r = self.ws(r).view(tgt.size(1), -1)
                result.append(r)
                _, topi = torch.topk(r, k=1, dim=1)
                embedded = self.embedding(topi.transpose(0, 1))
            result = torch.stack(result)
        return result, state
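# Note on teacher_forcing: it is used as a probability. teacher_forcing=1.0
# always takes the first branch (the whole ground-truth tgt is fed in at once),
# while teacher_forcing=0.0 always takes the generation branch, which starts
# from tgt[:1] (assumed to hold the <sos> token) and feeds the argmax of each
# step back in as the next input. Values in between pick one branch at random
# per forward call.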
class Seq2seq(nn.Module):

    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Seq2seq, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )
        self.decoder = Decoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )
        self.fc = nn.Linear(hidden_size*2, hidden_size)

    def forward(self, src, src_len, tgt, teacher_forcing):
        """
        :param src: index, (src_seq_len, batch_size)
        :param src_len: tensor, (batch_size)
        :param tgt: index, (tgt_seq_len, batch_size)
        :param teacher_forcing: probability of using the ground-truth tokens as input
        :return: result (tgt_seq_len, batch_size, vocab_size), state (num_layers, batch_size, hidden_size)
        """
        # encode
        outputs, h_t = self.encoder(src, src_len)
        # bridge: project the bidirectional encoder state (hidden_size*2) down
        # to the decoder's hidden_size
        state = h_t.view(-1, self.hidden_size * 2)
        state = torch.tanh(self.fc(state)).view(self.num_layers, -1, self.hidden_size)
        # decode
        result, state = self.decoder(tgt, state, outputs, src_len, teacher_forcing)
        return result, state

    def gen(self, index, num_beams, max_len):
        """
        test mode
        :param index: a single src sample, tensor (seq_len)
        :param num_beams: beam width
        :param max_len: max length of the result
        :return: result, list
        """
        src = index.unsqueeze(1)
        src_len = torch.LongTensor([src.size(0)])
        # encode
        outputs, h_t = self.encoder(src, src_len)
        state = h_t.view(-1, self.hidden_size * 2)
        state = torch.tanh(self.fc(state)).view(self.num_layers, -1, self.hidden_size)
        # decode with beam search
        result = test_helper.beam_search(self.decoder, num_beams, max_len, state, outputs, src_len)
        # drop the trailing end token if present (token id 2 is assumed to be <eos>)
        if result[-1] == 2:
            return result[:-1]
        else:
            return result[:]
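Finally, a minimal smoke test for the whole model (a sketch with made-up sizes; since Attn calls .cuda(), this particular sketch assumes a GPU is available):

if __name__ == '__main__':
    vocab_size, emb_dim, hidden = 1000, 64, 128
    emb = nn.Embedding(vocab_size, emb_dim)
    model = Seq2seq(input_size=emb_dim, hidden_size=hidden, embedding=emb).cuda()
    src = torch.randint(1, vocab_size, (7, 4)).cuda()   # (src_seq_len, batch_size)
    src_len = torch.tensor([7, 6, 5, 3])                # sorted in decreasing order
    tgt = torch.randint(1, vocab_size, (5, 4)).cuda()   # (tgt_seq_len, batch_size)
    result, state = model(src, src_len, tgt, teacher_forcing=1.0)
    print(result.size())   # expected: (5, 4, 1000)
    print(state.size())    # expected: (1, 4, 128)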