

seq2seq with attention: code implementation

For the theory behind the seq2seq model with attention, please refer to the companion post: seq2seq + attention 詳解

A seq2seq model with Luong attention is implemented as follows:

# coding=utf-8
# author = 'xy'

"""
model2: encoder + attn + decoder
we use a bidirectional GRU as the encoder, a GRU as the decoder, and Luong attention (concat method) as the attention mechanism
It follows the paper "Effective Approaches to Attention-based Neural Machine Translation"
"""

import numpy as np
import torch
from torch import nn
from torch.nn import functional as f
import test_helper


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Encoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout

        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True
        )

    def forward(self, src, src_len):
        """
        :param src: tensor, cuda, (seq_len, batch_size)
        :param src_len: tensor, (batch_size)
        :return: outputs(seq_len, batch_size, hidden_size*2), h_t(num_layers, batch_size, hidden_size*2)
        """

        src = self.embedding(src)
        src = nn.utils.rnn.pack_padded_sequence(src, src_len)
        outputs, h_t = self.rnn(src, None)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # merge forward/backward final states: (num_layers*2, batch, hidden) -> (num_layers, batch, hidden*2)
        h_t = torch.cat((h_t[0::2], h_t[1::2]), dim=2)
        return outputs, h_t
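
# Illustrative shape check for Encoder (the sizes below are assumptions for
# illustration, not values from the original post):
#   emb = nn.Embedding(1000, 16)
#   enc = Encoder(input_size=16, hidden_size=32, embedding=emb).cuda()
#   src = torch.randint(0, 1000, (7, 2)).cuda()   # (seq_len, batch_size)
#   src_len = torch.LongTensor([7, 5])            # lengths in decreasing order for packing
#   outputs, h_t = enc(src, src_len)              # (7, 2, 64), (1, 2, 64)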


class Attn(nn.Module):
    def __init__(self, hidden_size):
        super(Attn, self).__init__()

        self.hidden_size = hidden_size

        self.fc1 = nn.Linear(hidden_size*2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1)

    def forward(self, outputs, src_len, ss):
        """
        :param outputs: h tensor, (src_seq_len, batch_size, hidden_size*2)
        :param src_len: tensor, (batch_size)
        :param ss: s tensor, (tgt_seq_len, batch_size, hidden_size)
        :return: content tensor, (batch_size, tgt_seq_len, hidden_size*2)
        """

        src_seq_len = outputs.size(0)
        tgt_seq_len = ss.size(0)
        batch_size = outputs.size(1)

        # W1 * h: project the encoder states
        h = outputs.view(-1, self.hidden_size*2)
        wh = self.fc1(h).view(src_seq_len, batch_size, self.hidden_size)
        wh = wh.transpose(0, 1)
        wh = wh.unsqueeze(1)
        wh = wh.expand(batch_size, tgt_seq_len, src_seq_len, self.hidden_size)

        # W2 * s: project the decoder states
        s = ss.view(-1, self.hidden_size)
        ws = self.fc2(s).view(tgt_seq_len, batch_size, self.hidden_size)
        ws = ws.transpose(0, 1)
        ws = ws.unsqueeze(2)
        ws = ws.expand(batch_size, tgt_seq_len, src_seq_len, self.hidden_size)

        # score(h, s) = v^T tanh(W1*h + W2*s)
        hs = torch.tanh(wh + ws)
        hs = hs.view(-1, self.hidden_size)
        hs = self.v(hs).view(batch_size, tgt_seq_len, src_seq_len)

        # mask out padded source positions before the softmax
        mask = []
        for i in src_len:
            i = i.item()
            mask.append([0]*i + [1]*(src_seq_len-i))
        mask = torch.tensor(mask, dtype=torch.bool)
        mask = mask.unsqueeze(1)
        mask = mask.expand(batch_size, tgt_seq_len, src_seq_len).cuda()
        hs.masked_fill_(mask, -float('inf'))

        # attention weights and context vectors
        hs = f.softmax(hs, dim=2)
        hs = torch.bmm(hs, outputs.transpose(0, 1))
        return hs
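
# The block above is Luong's "concat" score, implemented with two linear layers:
#   score(h_i, s_j) = v^T tanh(W1*h_i + W2*s_j)
# followed by a softmax over source positions (padded positions masked with -inf)
# and a weighted sum (bmm) of the encoder states. Illustrative call, with assumed sizes:
#   attn = Attn(hidden_size=32)
#   context = attn(outputs, src_len, ss)   # (batch_size, tgt_seq_len, 64)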


class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Decoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout

        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=False
        )

        self.attn = Attn(hidden_size)
        self.fc = nn.Linear(hidden_size*2, hidden_size)
        self.fh = nn.Linear(hidden_size, hidden_size)
        self.ws = nn.Linear(hidden_size, embedding.num_embeddings)

    def forward(self, tgt, state, outputs, src_len, teacher_forcing):
        """
        :param tgt: index, (tgt_seq_len, batch_size)
        :param state: s_t, (num_layers, batch_size, hidden_size)
        :param outputs: h, (src_seq_len, batch_size, hidden_size*2)
        :param src_len: tensor, (batch_size)
        :param teacher_forcing:
        :return: results(tgt_seq_len, batch_size, vocab_size), state(num_layers, batch_size, hidden_size)
        """
        flag = np.random.random() < teacher_forcing

        # teacher_forcing mode, also for testing mode
        if flag:
            embedded = self.embedding(tgt)
            ss, state = self.rnn(embedded, state)
            content = self.attn(outputs, src_len, ss).transpose(0, 1)  # (tgt_seq_len, batch_size, hidden_size*2)
            content = content.contiguous().view(-1, self.hidden_size*2)
            ss = ss.view(-1, self.hidden_size)
            result = torch.tanh(self.fc(content) + self.fh(ss))
            result = self.ws(result).view(tgt.size(0), tgt.size(1), -1)

        # generation mode
        else:
            result = []
            embedded = self.embedding(tgt[:1])  # first target token, e.g. <sos>
            for _ in range(tgt.size(0)):
                ss, state = self.rnn(embedded, state)
                content = self.attn(outputs, src_len, ss).transpose(0, 1)
                content = content.view(-1, self.hidden_size*2)
                ss = ss.view(-1, self.hidden_size)
                r = torch.tanh(self.fc(content) + self.fh(ss))
                r = self.ws(r).view(tgt.size(1), -1)
                result.append(r)

                # feed the greedy (top-1) prediction back as the next input
                _, topi = torch.topk(r, k=1, dim=1)
                embedded = self.embedding(topi.transpose(0, 1))
            result = torch.stack(result)

        return result, state
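
# Note on the two branches above: with probability `teacher_forcing` the decoder
# consumes the whole gold target sequence in a single RNN call; otherwise it runs
# step by step, feeding back its own top-1 prediction. Illustrative call (sizes
# and tensors are assumptions, not from the original post):
#   dec = Decoder(input_size=16, hidden_size=32, embedding=emb).cuda()
#   result, state = dec(tgt, state, outputs, src_len, teacher_forcing=1.0)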


class Seq2seq(nn.Module):
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.0):
        super(Seq2seq, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout

        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )

        self.decoder = Decoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )

        self.fc = nn.Linear(hidden_size*2, hidden_size)

    def forward(self, src, src_len, tgt, teacher_forcing):
        """
        :param src: index, (src_seq_len, batch_size)
        :param src_len: tensor, (batch_size)
        :param tgt: index, (tgt_seq_len, batch_size)
        :param teacher_forcing:
        :return: outputs(tgt_seq_len, batch_size, vocab_size), state(num_layers, batch_size, hidden_size)
        """
        # encode
        outputs, h_t = self.encoder(src, src_len)
        state = h_t.view(-1, self.hidden_size * 2)
        state = torch.tanh(self.fc(state)).view(self.num_layers, -1, self.hidden_size)

        # decode
        result, state = self.decoder(tgt, state, outputs, src_len, teacher_forcing)

        return result, state

    def gen(self, index, num_beams, max_len):
        """
        test mode
        :param index: a single source sequence, tensor (src_seq_len,)
        :param num_beams:
        :param max_len: max length of result
        :return: result, list
        """
        src = index.unsqueeze(1)
        src_len = torch.LongTensor([src.size(0)])

        # encode
        outputs, h_t = self.encoder(src, src_len)
        state = h_t.view(-1, self.hidden_size * 2)
        state = torch.tanh(self.fc(state)).view(self.num_layers, -1, self.hidden_size)

        # decoder
        result = test_helper.beam_search(self.decoder, num_beams, max_len, state, outputs, src_len)
        # drop the trailing token if it is the <eos> marker
        # (the concrete id below is an assumed placeholder; the original constant was lost)
        if result[-1] == 1:
            return result[:-1]
        else:
            return result[:]
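
A minimal end-to-end sketch of how the model above might be wired up for one training step. The vocabulary size, dimensions, and the loss call are illustrative assumptions, not values from the original post; the snippet assumes the classes above are defined in the same file, and a GPU is required because the attention mask is moved to CUDA:

import torch
from torch import nn

vocab_size, emb_size, hidden_size = 1000, 16, 32                # assumed sizes
embedding = nn.Embedding(vocab_size, emb_size)
model = Seq2seq(input_size=emb_size, hidden_size=hidden_size,
                embedding=embedding, num_layers=1, dropout=0.0).cuda()

src = torch.randint(0, vocab_size, (7, 2)).cuda()               # (src_seq_len, batch_size)
src_len = torch.LongTensor([7, 5])                              # decreasing lengths for packing
tgt = torch.randint(0, vocab_size, (5, 2)).cuda()               # (tgt_seq_len, batch_size)

result, state = model(src, src_len, tgt, teacher_forcing=1.0)   # always use teacher forcing
# real training would shift the target and mask padding; omitted in this sketch
loss = nn.functional.cross_entropy(result.view(-1, vocab_size), tgt.view(-1))
loss.backward()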

           
