
Implementing the Helper in the TensorFlow Seq2Seq Decoder Stage

BasicDecoder and dynamic_decode

To keep things simple, let's start the analysis from dynamic_decode, the entry point of decoding:

dynamic_decode(
   decoder,
   output_time_major=False,
   impute_finished=False,
   maximum_iterations=None,
   parallel_iterations=32,
   swap_memory=False,
   scope=None
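)

decoder: a BasicDecoder, a BeamSearchDecoder, or an instance of your own decoder class.
output_time_major: Python boolean. When False, outputs are batch-major and returned as [batch_size, time_step, ...]; this mode adds extra computation time. When True, outputs are returned as [time_step, batch_size, ...], which is faster.
impute_finished: Python boolean. When True, the states of batch entries marked as finished are copied through and their outputs are zeroed out. This makes every time step somewhat slower, but ensures that the final state and outputs have the correct values and that backpropagation ignores the time steps marked as finished.
maximum_iterations: the maximum number of decoding steps. During training it is usually set from decoder_inputs_length; at inference, set it to the maximum sequence length you want. Decoding stops when <eos> is produced or the maximum number of steps is reached.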
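
To make the two output layouts concrete, here is a minimal pure-Python sketch (no TensorFlow; nested lists stand in for tensors, and the function name to_time_major is hypothetical). Since body() writes one time step at a time into outputs_ta, the collected result is naturally time-major; returning batch-major outputs costs an extra transpose like this one:

```python
def to_time_major(batch_major):
    """Swap the first two axes: [batch][time][...] -> [time][batch][...]."""
    if not batch_major:
        return []
    num_steps = len(batch_major[0])  # assumes every sequence has the same length
    return [[seq[t] for seq in batch_major] for t in range(num_steps)]


# batch_size=2, time_step=3, one scalar per step
batch_major = [[1, 2, 3],
               [4, 5, 6]]
print(to_time_major(batch_major))  # [[1, 4], [2, 5], [3, 6]]
```

Applying the function twice returns the original layout, which is why the conversion is purely a layout cost and never changes the values.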
           

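Simply put, dynamic_decode first calls the decoder's initialize function to initialize the state and other variables used during decoding, and then calls the decoder's step function repeatedly, decoding one step per iteration. In short, its body is equivalent to a for loop, and the core of it is a control_flow_ops.while_loop: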

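Here condition is the loop condition and body is the loop body; both are functions, implemented as shown below. loop_vars are the variables the loop uses: condition() and body() take the same arguments, namely loop_vars. Usually condition() uses only one of them to decide whether the loop should end, while most of them are needed only in body(). parallel_iterations is the number of loop iterations allowed to run in parallel. condition() simply checks whether every entry of finished is True, while body() runs

decoder.step(time, inputs, state)

followed by a series of assignments and checks.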

from tensorflow.python.ops import array_ops, control_flow_ops, math_ops, tensor_array_ops
from tensorflow.python.util import nest

# Excerpt from dynamic_decode: decoder, maximum_iterations, impute_finished,
# zero_outputs and the initial_* loop variables are defined earlier in that function.

# Loop condition: keep looping until every batch entry is finished
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
              finished, unused_sequence_lengths):
    return math_ops.logical_not(math_ops.reduce_all(finished))

# Loop body
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    ## 1. Call step() to get this step's outputs and new state, the next inputs
    ##    (produced by the helper) and the decoder_finished flags
    (next_outputs, decoder_state, next_inputs,
     decoder_finished) = decoder.step(time, inputs, state)
    ## 2. Decide whether decoding is finished from decoder_finished and from
    ##    whether time has reached maximum_iterations
    next_finished = math_ops.logical_or(decoder_finished, finished)
    if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
    # Entries that finish at this step get sequence length time + 1
    next_sequence_lengths = array_ops.where(
        math_ops.logical_and(math_ops.logical_not(finished), next_finished),
        array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
        sequence_lengths)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)
    ## 3. If impute_finished is True, zero out next_outputs for entries that are
    ##    already finished so they do not contribute to backprop (their state is
    ##    copied through in step 4). This costs some time per step, but keeps
    ##    the final outputs and state correct.
    if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs, zero_outputs)
    else:
        emit = next_outputs

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
            pass_through = True
        else:
            new.set_shape(cur.shape)
            pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

    ## 4. For finished entries, keep the previous state instead of the new one
    if impute_finished:
        next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    else:
        next_state = decoder_state
    ## 5. Write this step's output into outputs_ta and return the loop variables
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths)

# Run the decoding loop with the condition and body defined above
res = control_flow_ops.while_loop(
    condition, body,
    loop_vars=[initial_time, initial_outputs_ta, initial_state, initial_inputs,
               initial_finished, initial_sequence_lengths],
    parallel_iterations=parallel_iterations, swap_memory=swap_memory)
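
The while_loop can be mirrored in plain Python to watch how finished, sequence_lengths, maximum_iterations and impute_finished interact. This is a sketch under simplifying assumptions: one scalar output and state per batch entry, Python lists instead of tensors, and a hypothetical toy step function (toy_step) in place of decoder.step.

```python
def simulate_dynamic_decode(step_fn, initial_state, initial_inputs,
                            initial_finished, maximum_iterations=None,
                            impute_finished=False, zero_output=0):
    """Pure-Python mirror of dynamic_decode's condition/body loop.

    step_fn(time, inputs, state) must return
    (outputs, new_state, next_inputs, decoder_finished), each a list
    with one entry per batch element.
    """
    batch_size = len(initial_finished)
    time = 0
    outputs_per_step = []  # plays the role of outputs_ta
    state = list(initial_state)
    inputs = list(initial_inputs)
    finished = list(initial_finished)
    sequence_lengths = [0] * batch_size

    # condition(): keep looping while not every entry is finished
    while not all(finished):
        outputs, new_state, next_inputs, decoder_finished = step_fn(time, inputs, state)
        # next_finished = decoder_finished OR finished
        next_finished = [d or f for d, f in zip(decoder_finished, finished)]
        # ... OR (time + 1 >= maximum_iterations)
        if maximum_iterations is not None and time + 1 >= maximum_iterations:
            next_finished = [True] * batch_size
        # Entries that finish at this step get length time + 1
        sequence_lengths = [time + 1 if (not f and nf) else length
                            for f, nf, length in zip(finished, next_finished, sequence_lengths)]
        if impute_finished:
            # Zero the output and keep the old state for already-finished entries
            emit = [zero_output if f else o for f, o in zip(finished, outputs)]
            state = [cur if f else new for f, cur, new in zip(finished, state, new_state)]
        else:
            emit = outputs
            state = list(new_state)
        outputs_per_step.append(emit)  # like outputs_ta.write(time, emit)
        time, inputs, finished = time + 1, next_inputs, next_finished
    return outputs_per_step, state, sequence_lengths


# Hypothetical toy decoder: each entry counts up and finishes at its target.
targets = [2, 4]

def toy_step(time, inputs, state):
    new_state = [s + 1 for s in state]
    outputs = [10 * (time + 1) for _ in state]
    decoder_finished = [s >= t for s, t in zip(new_state, targets)]
    return outputs, new_state, inputs, decoder_finished


outs, final_state, lengths = simulate_dynamic_decode(
    toy_step, [0, 0], [None, None], [False, False], impute_finished=True)
print(outs)         # [[10, 10], [20, 20], [0, 30], [0, 40]]
print(final_state)  # [2, 4]
print(lengths)      # [2, 4]
```

Once the first entry finishes, its outputs are zeroed and its state is frozen at 2 while its recorded length stays 2; with impute_finished=False its output and state would keep changing until the whole batch finishes.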
           

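So what exactly does decoder.step() do? You can think of it as running the RNN cell forward by one step. Because this is decoding, it additionally uses the helper to pick the output answer and to convert that answer into the input for the next time step, as shown below: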

# Method of BasicDecoder; ops and BasicDecoderOutput come from the surrounding module.
def step(self, time, inputs, state, name=None):
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        cell_outputs, cell_state = self._cell(inputs, state)
        if self._output_layer is not None:
            # If an output layer is set, project the cell outputs (e.g. to vocabulary logits)
            cell_outputs = self._output_layer(cell_outputs)
        # Pick the answer from the outputs: e.g. greedy decoding takes the most
        # probable word, scheduled sampling samples from some distribution
        sample_ids = self._helper.sample(
            time=time, outputs=cell_outputs, state=cell_state)
        # Turn the result into the next step's input: during training this is the
        # next entry of decoder_inputs; at inference it is the chosen word's embedding
        (finished, next_inputs, next_state) = self._helper.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state,
            sample_ids=sample_ids)
    # BasicDecoderOutput is a namedtuple bundling both values into one outputs variable
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)
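
The control flow of step() can be reproduced without TensorFlow. In this hypothetical sketch, ToyBasicDecoder, toy_cell, toy_output_layer and ArgmaxHelper are illustrative stand-ins (not library classes), and each batch entry is a plain Python number or list:

```python
from collections import namedtuple

# Same two fields as BasicDecoderOutput: the projected cell output and the chosen ids.
BasicDecoderOutput = namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])


class ToyBasicDecoder:
    """Hypothetical pure-Python stand-in for BasicDecoder.step()."""

    def __init__(self, cell, helper, output_layer=None):
        self._cell = cell
        self._helper = helper
        self._output_layer = output_layer

    def step(self, time, inputs, state):
        # 1. Run the cell for one step
        cell_outputs, cell_state = self._cell(inputs, state)
        # 2. Optionally project the cell output (e.g. to vocabulary logits)
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)
        # 3. Let the helper choose ids from the outputs
        sample_ids = self._helper.sample(time=time, outputs=cell_outputs, state=cell_state)
        # 4. Let the helper decide finished flags and the next inputs
        finished, next_inputs, next_state = self._helper.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids)
        return BasicDecoderOutput(cell_outputs, sample_ids), next_state, next_inputs, finished


def toy_cell(inputs, state):
    # Toy "RNN": new hidden value = old state + input; output equals the new state
    new_state = [s + x for s, x in zip(state, inputs)]
    return new_state, new_state


def toy_output_layer(hidden):
    # Toy projection to 3-way logits: one-hot at hidden % 3
    return [[1.0 if k == h % 3 else 0.0 for k in range(3)] for h in hidden]


class ArgmaxHelper:
    end_token = 2

    def sample(self, time, outputs, state):
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = [i == self.end_token for i in sample_ids]
        return finished, list(sample_ids), state


decoder = ToyBasicDecoder(toy_cell, ArgmaxHelper(), toy_output_layer)
outputs, next_state, next_inputs, finished = decoder.step(0, inputs=[1, 2], state=[0, 0])
print(outputs.sample_id, finished)  # [1, 2] [False, True]
```

The decoder itself never decides what the next input is; swapping the helper object is enough to switch between training-style and inference-style decoding.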
           

TrainingHelper, GreedyEmbeddingHelper and CustomHelper in the helper module

Next, let's look at what the initialize, sample and next_inputs functions of each helper class do.

TrainingHelper

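Generally used to drive Decoder decoding during training: at each step it feeds the ground-truth decoder input for the next time step (teacher forcing) rather than the model's own prediction.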

# Methods of TrainingHelper; self._input_tas holds the decoder inputs unstacked
# into TensorArrays, one element per time step.

# Initialize finished and the initial inputs
def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
        # Entries with sequence_length 0 are finished from the start
        finished = math_ops.equal(0, self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
    return (finished, next_inputs)

def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
        # Use argmax to take the index of the largest value in outputs
        sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs", [time, outputs, state]):
        next_time = time + 1
        # finished stays False while the next step is still shorter than the sequence length
        finished = (next_time >= self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        # Read the next time step directly from decoder_inputs as the next decoding input
        def read_from_ta(inp):
            return inp.read(next_time)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(read_from_ta, self._input_tas))
    return (finished, next_inputs, state)
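
A pure-Python sketch of the same bookkeeping follows. ToyTrainingHelper is a hypothetical name, the inputs are assumed to be padded to a common length (as real decoder inputs are), and zero_input stands in for self._zero_inputs. It shows that next_inputs ignores outputs entirely and simply reads the ground truth at next_time:

```python
class ToyTrainingHelper:
    """Hypothetical pure-Python sketch of TrainingHelper's logic.

    inputs: batch-major ground-truth decoder inputs, padded to one length.
    sequence_length: true length of each example.
    """

    def __init__(self, inputs, sequence_length, zero_input=0):
        self._inputs = inputs
        self._sequence_length = sequence_length
        self._zero_input = zero_input

    def _read(self, t):
        # Like reading step t from each input TensorArray
        return [seq[t] for seq in self._inputs]

    def _zeros(self):
        return [self._zero_input] * len(self._inputs)

    def initialize(self):
        # Examples of length 0 are finished before decoding starts
        finished = [length == 0 for length in self._sequence_length]
        next_inputs = self._zeros() if all(finished) else self._read(0)
        return finished, next_inputs

    def sample(self, time, outputs, state=None):
        # argmax over each row of logits; not used to build the next inputs
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids=None):
        next_time = time + 1
        finished = [next_time >= length for length in self._sequence_length]
        # Teacher forcing: read the ground truth at next_time, ignoring outputs
        next_inputs = self._zeros() if all(finished) else self._read(next_time)
        return finished, next_inputs, state


helper = ToyTrainingHelper(inputs=[[5, 6, 7], [8, 9, 0]], sequence_length=[3, 2])
print(helper.initialize())                   # ([False, False], [5, 8])
print(helper.next_inputs(1, None, "state"))  # ([False, True], [7, 0], 'state')
```

The second example is already finished at step 2, yet its padded input 0 is still read because the batch as a whole is not done; dynamic_decode's finished handling takes care of that entry.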
           

GreedyEmbeddingHelper

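Generally used for Decoder decoding at inference time: it picks each word greedily (argmax) and feeds that word's embedding back in as the next step's input.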

# Methods of GreedyEmbeddingHelper; self._embedding_fn wraps embedding_lookup and
# self._start_inputs is the embedding of start_tokens.

# Initialize finished and the initial inputs
def initialize(self, name=None):
    # All False at the initial step
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

def sample(self, time, outputs, state, name=None):
    del time, state  # unused by sample_fn
    if not isinstance(outputs, ops.Tensor):
        raise TypeError("Expected outputs to be a single Tensor, got: %s" % type(outputs))
    # Use argmax to take the index of the largest logit
    sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
    return sample_ids

def next_inputs(self, time, outputs, state, sample_ids, name=None):
    del time, outputs  # unused by next_inputs_fn
    # An entry is finished once it emits end_token
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    # Convert sample_ids into the next step's word vectors via embedding_lookup(embedding, ids)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
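
The same greedy logic, sketched in plain Python under stated assumptions: the embedding is a list of vectors indexed by token id, and ToyGreedyEmbeddingHelper is a hypothetical name. Note that finished here only reports whether an entry emitted end_token at this step; dynamic_decode's body ORs it with the previous finished flags.

```python
class ToyGreedyEmbeddingHelper:
    """Hypothetical pure-Python sketch of GreedyEmbeddingHelper's logic."""

    def __init__(self, embedding, start_tokens, end_token):
        self._embedding = embedding  # list of vectors, indexed by token id
        self._end_token = end_token
        self._batch_size = len(start_tokens)
        self._start_inputs = self._embedding_fn(start_tokens)

    def _embedding_fn(self, ids):
        # Plays the role of embedding_lookup(embedding, ids)
        return [self._embedding[i] for i in ids]

    def initialize(self):
        # All False at the initial step
        return [False] * self._batch_size, self._start_inputs

    def sample(self, time, outputs, state):
        # Greedy choice: index of the largest logit in each row
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = [i == self._end_token for i in sample_ids]
        if all(finished):
            next_inputs = self._start_inputs  # value doesn't matter any more
        else:
            next_inputs = self._embedding_fn(sample_ids)
        return finished, next_inputs, state


embedding = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # vocab of 3, dim 2
helper = ToyGreedyEmbeddingHelper(embedding, start_tokens=[1, 1], end_token=2)
ids = helper.sample(0, [[0.1, 0.8, 0.1], [0.2, 0.1, 0.7]], None)
print(ids)                                        # [1, 2]
print(helper.next_inputs(0, None, "s", ids)[:2])  # ([False, True], [[1.0, 0.0], [0.0, 1.0]])
```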
           

CustomHelper

In general we use CustomHelper when the Decoder stage needs the output of the previous time step as part of its input, which means the inputs cannot be packed in advance. A standard dynamic RNN computes $s_i = f(s_{i-1}, x_i)$; but sometimes this function needs more arguments, as in our case: $s_i = f(s_{i-1}, y_{i-1}, h_i, c_i)$.

So we need a hack: use tf.contrib.seq2seq.CustomHelper and pass it three functions:

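initial_fn(): returns the finished flags and the input for the first time step.

sample_fn(): determines how to go from the logits to one specific class id.

next_inputs_fn(): determines the input for each subsequent time step.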

import tensorflow as tf  # TF 1.x, where tf.contrib.seq2seq is available

# The three functions passed to CustomHelper. They are closures defined inside a
# model method, so self.decoder_seq_length, self.start_inputs and decoder_embedding
# come from that model.

# Initialize finished and the inputs
def initial_fn():
    # False for every entry whose decoder_seq_length is nonzero
    initial_finished = (0 >= self.decoder_seq_length)
    return (initial_finished, self.start_inputs)

def sample_fn(time, outputs, state):
    # del time, state  # unused by sample_fn
    # Use argmax to take the index of the largest logit
    sample_ids = tf.cast(tf.argmax(outputs, axis=-1), dtype=tf.int32)
    return sample_ids

def next_inputs_fn(time, outputs, state, sample_ids):
    # Look up the embedding of the previous step's predicted class and use it
    # as the input for the next time step
    next_input = tf.nn.embedding_lookup(decoder_embedding, sample_ids)
    time += 1  # the next step is time + 1; otherwise the logits get one extra time step
    # this operation produces boolean tensor of [batch_size]
    elements_finished = (time >= self.decoder_seq_length)
    # -> boolean scalar, marks that the whole batch has finished
    all_finished = tf.reduce_all(elements_finished)
    # If finished, the next_inputs value doesn't matter
    next_inputs = tf.cond(all_finished, lambda: self.start_inputs, lambda: next_input)
    return elements_finished, next_inputs, state

# Using the custom helper
helper = tf.contrib.seq2seq.CustomHelper(initial_fn, sample_fn, next_inputs_fn)
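
CustomHelper is mostly plumbing: it stores the three callables and forwards the decoder's calls to them. The sketch below mirrors that design in plain Python; ToyCustomHelper and the toy data are hypothetical, and module-level variables replace the self.* attributes used above:

```python
class ToyCustomHelper:
    """Hypothetical stand-in: delegates to three user-supplied callables."""

    def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
        self._initialize_fn = initialize_fn
        self._sample_fn = sample_fn
        self._next_inputs_fn = next_inputs_fn

    def initialize(self):
        return self._initialize_fn()

    def sample(self, time, outputs, state):
        return self._sample_fn(time=time, outputs=outputs, state=state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self._next_inputs_fn(time=time, outputs=outputs, state=state,
                                    sample_ids=sample_ids)


# Hypothetical toy data: 3-word vocabulary, 2-dim embeddings, batch of 2.
decoder_embedding = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
decoder_seq_length = [2, 1]
start_inputs = [[0.0, 0.0], [0.0, 0.0]]


def initial_fn():
    # Entries with length 0 are finished before decoding starts
    return [0 >= n for n in decoder_seq_length], start_inputs


def sample_fn(time, outputs, state):
    # argmax over each row of logits
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def next_inputs_fn(time, outputs, state, sample_ids):
    # Feed back the embedding of the previous step's prediction
    next_input = [decoder_embedding[i] for i in sample_ids]
    next_time = time + 1
    elements_finished = [next_time >= n for n in decoder_seq_length]
    if all(elements_finished):
        return elements_finished, start_inputs, state  # value doesn't matter
    return elements_finished, next_input, state


helper = ToyCustomHelper(initial_fn, sample_fn, next_inputs_fn)
ids = helper.sample(0, [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]], None)
print(ids)  # [2, 0]
```

Because the helper only forwards calls, any extra information the model needs (such as $y_{i-1}$ in the recurrence above) can be built inside next_inputs_fn without changing the decoder.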
           
