Bahdanau Attention
The basic seq2seq attention is Bahdanau attention:
import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W_s = tf.keras.layers.Dense(units)
        self.W_h = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, dec_hidden, enc_output):
        # dec_hidden (the query) is the decoder GRU hidden state from the previous step
        # enc_output (the values) is the encoder's output
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)
        # score: [batch_sz, max_len, 1]
        score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis)))
        attention_weights = tf.nn.softmax(score, axis=1)
        # weighted sum over the time axis -> [batch_sz, enc_units]
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
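A quick usage sketch follows, just to make the tensor shapes concrete; the shapes and names (batch_sz, max_len, enc_units, dec_units) are illustrative assumptions, not from the original:

# Hypothetical shapes, only to show how the layer is called.
batch_sz, max_len, enc_units, dec_units, units = 4, 10, 32, 32, 16

attention = BahdanauAttention(units)
enc_output = tf.random.normal((batch_sz, max_len, enc_units))   # encoder outputs (values)
dec_hidden = tf.random.normal((batch_sz, dec_units))            # previous decoder hidden state (query)

context_vector, attention_weights = attention(dec_hidden, enc_output)
print(context_vector.shape)      # (4, 32)    -> [batch_sz, enc_units]
print(attention_weights.shape)   # (4, 10, 1) -> [batch_sz, max_len, 1]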
Coverage Attention
Coverage attention keeps a coverage vector, the sum of the attention distributions from all previous decoder steps, and adds it as an extra input to the attention score so the model can avoid attending to the same positions repeatedly:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W_s = tf.keras.layers.Dense(units)
        self.W_h = tf.keras.layers.Dense(units)
        self.W_c = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage):
        # dec_hidden is the decoder hidden state (the query)
        # enc_output is the encoder's output (the values)
        # enc_pad_mask and use_coverage are accepted for interface compatibility but not used here
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)
        # self.W_s(enc_output): [batch_sz, max_len, units]
        # self.W_h(hidden_with_time_axis): [batch_sz, 1, units]
        # self.W_c(prev_coverage): [batch_sz, max_len, units]
        # score: [batch_sz, max_len, 1]
        score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis) + self.W_c(prev_coverage)))
        attention_weights = tf.nn.softmax(score, axis=1)
        # [batch_sz, max_len, enc_units]
        context_vector = attention_weights * enc_output
        # [batch_sz, enc_units]
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
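Note that this version always adds W_c(prev_coverage) to the score, so a coverage tensor must be supplied even at the first decoding step. A minimal sketch, reusing the tensors from the example above and assuming an all-zeros initial coverage of shape [batch_sz, max_len, 1] (an assumption, not stated in the original):

# Assumed initialisation: zero coverage before the first decoder step.
prev_coverage = tf.zeros((batch_sz, max_len, 1))
enc_pad_mask = tf.ones((batch_sz, max_len))   # 1 = real token, 0 = padding (not applied by this layer)

cov_attention = BahdanauAttention(units)
context_vector, attention_weights = cov_attention(
    dec_hidden, enc_output, enc_pad_mask,
    use_coverage=True, prev_coverage=prev_coverage)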
Combining the two cases, with a switch that turns coverage on or off and handles the first decoding step, where there is no previous coverage yet:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W_s = tf.keras.layers.Dense(units)
        self.W_h = tf.keras.layers.Dense(units)
        self.W_c = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, dec_hidden, enc_output, enc_pad_mask, use_coverage, prev_coverage):
        hidden_with_time_axis = tf.expand_dims(dec_hidden, 1)
        if use_coverage and prev_coverage is not None:
            # add the coverage term to the score and accumulate the new attention into coverage
            score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis) + self.W_c(prev_coverage)))
            attention_weights = tf.nn.softmax(score, axis=1)
            coverage = attention_weights + prev_coverage
        else:
            score = self.V(tf.nn.tanh(self.W_s(enc_output) + self.W_h(hidden_with_time_axis)))
            attention_weights = tf.nn.softmax(score, axis=1)
            coverage = attention_weights if use_coverage else None
        context_vector = tf.reduce_sum(attention_weights * enc_output, axis=1)
        return context_vector, attention_weights, coverage
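A minimal decoding-loop sketch (assumed, not from the original) showing how the returned coverage feeds the next step; it reuses the tensors from the sketches above, and in a real model dec_hidden would be updated by the decoder GRU at every step:

# Assumed decoding loop: coverage starts at zero and accumulates step by step.
coverage_attention = BahdanauAttention(units)
coverage = tf.zeros((batch_sz, max_len, 1))
for t in range(5):  # 5 decoder steps, purely illustrative
    context_vector, attention_weights, coverage = coverage_attention(
        dec_hidden, enc_output, enc_pad_mask,
        use_coverage=True, prev_coverage=coverage)
    # dec_hidden would normally be recomputed by the decoder GRU here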