
[NLP] Source Code Analysis of tf.contrib.seq2seq.GreedyEmbeddingHelper

Preface

This post follows on from the TrainingHelper post, and can also be read alongside the BasicDecoder one. To set the stage: GreedyEmbeddingHelper's main job is to take the start tokens and then generate a sentence greedily, one token at a time, until the end token is emitted (or decoding hits its step limit).

Main Text

Link to the GreedyEmbeddingHelper source code

class GreedyEmbeddingHelper(Helper):
  """A helper for use during inference.
  Uses the argmax of the output (treated as logits) and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token):
    """Initializer.
    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`. The returned tensor
        will be passed to the decoder input.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
    Raises:
      ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
        scalar.
    """
    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._batch_size = array_ops.size(start_tokens)
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")
    self._start_inputs = self._embedding_fn(self._start_tokens)
           

During initialization, GreedyEmbeddingHelper receives an embedding matrix (or a callable) so that embedding_lookup can be applied later. Note that TrainingHelper does not need this: during training, what we feed TrainingHelper is a [batch_size, seq_len, embed_size] tensor, i.e. the inputs are already word vectors. During inference, by contrast, all we supply is the start tokens and a limit on how long the output may be, so every word the decoder emits must still be looked up as a word vector before it can serve as the input at the next time step.
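As a minimal sketch (TF 1.x; vocab_size, embed_size, batch_size and the <GO>/<EOS> ids below are made-up assumptions, not from the source), the two ways of supplying the embedding look like this:

import tensorflow as tf

vocab_size, embed_size, batch_size = 10000, 128, 32  # hypothetical sizes
embedding_matrix = tf.get_variable("embedding", [vocab_size, embed_size])
start_tokens = tf.fill([batch_size], 1)  # assume id 1 is <GO>
end_token = 2                            # assume id 2 is <EOS>

# Option 1: pass the matrix; __init__ wraps it in embedding_lookup itself.
helper_a = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding_matrix, start_tokens, end_token)

# Option 2: pass a callable, e.g. to reuse an existing lookup function.
helper_b = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    lambda ids: tf.nn.embedding_lookup(embedding_matrix, ids),
    start_tokens, end_token)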

def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)
           

About the first input: in TrainingHelper the first input is inputs[0], whereas here it is the embedded start tokens. Note that start_tokens is a [batch_size] vector whose elements are not necessarily all the same, because sometimes we begin inference from a half-generated sentence, in which case the "start token" for that example is the last word of the partial sentence. Naturally, finished starts out all False here.
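Reusing helper_a from the sketch above, what initialize() hands back is simply that all-False mask plus the embedded start tokens:

# finished:     [batch_size] bool tensor, all False at step 0
# first_inputs: [batch_size, embed_size] embedded start tokens
finished, first_inputs = helper_a.initialize()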

def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids
           

This is the sampling step, which decides which word gets emitted at each time step. "Greedy" means exactly that: the sampling follows a greedy strategy, taking the argmax over the output logits and picking the word with the highest probability.
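A toy numeric illustration of that rule (plain NumPy, not part of the TF source): argmax over the last axis picks, for each example in the batch, the id with the largest logit.

import numpy as np

logits = np.array([[0.1, 2.5, 0.3],   # argmax -> id 1
                   [1.7, 0.2, 0.9]])  # argmax -> id 0
sample_ids = np.argmax(logits, axis=-1)
print(sample_ids)  # [1 0]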

def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
           

Unlike TrainingHelper, GreedyEmbeddingHelper really does have work to do in next_inputs, because the word sampled at the previous step must be embedded and fed back as the current input. It also compares sample_ids against the end token to mark finished examples; once every example is finished, the next input no longer matters, so the start inputs are reused as a placeholder.
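To see the three methods working together, here is a hedged end-to-end sketch that plugs the helper into BasicDecoder and dynamic_decode; decoder_cell, the Dense output layer, and maximum_iterations=50 are my assumptions, not part of the source above.

decoder_cell = tf.nn.rnn_cell.LSTMCell(256)
output_layer = tf.layers.Dense(vocab_size)  # projects cell outputs to logits

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=decoder_cell,
    helper=helper_a,  # the inference helper built earlier
    initial_state=decoder_cell.zero_state(batch_size, tf.float32),
    output_layer=output_layer)

# Each example stops once it samples <EOS>; decoding as a whole stops
# when every example is finished or maximum_iterations is reached.
outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=50)
predicted_ids = outputs.sample_id  # [batch_size, decoded_len]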

Summary

There are many Helper types: SampleEmbeddingHelper, CustomHelper, ScheduledEmbeddingTrainingHelper, ScheduledOutputTrainingHelper, InferenceHelper, and so on. Most of them are broadly similar. Once you understand the canonical training-phase Helper and the canonical inference-phase Helper, i.e. the two covered in this series, the rest follow by analogy.

All of the code lives in helper.py; if you want to dig deeper, keep reading there.
