天天看點

TensorFlow 使用 HMM 的 Viterbi 算法計算誤差

def _viterbi_decode_one(unary, transition, start, num_step):
    """Run Viterbi decoding on a single sequence.

    Args:
        unary: Tensor [num_step, output_dims]; per-step tag scores.
        transition: Tensor [output_dims, output_dims]; transition[i, j] is
            the score of moving from tag i to tag j.
        start: Tensor [output_dims]; initial tag scores.
        num_step: Python int, static length of the sequence (>= 1).

    Returns:
        (final_scores, path): final_scores is a [output_dims] tensor whose
        entry j is the best score of any path ending in tag j at the last
        step; path is a list of num_step scalar int64 tensors, the best tag
        at each step.
    """
    # Best score of any path ending in each tag so far, shape [output_dims].
    # Kept rank-1 from the start so every step below has the same shapes
    # (and a length-1 sequence still yields a scalar final tag).
    best_scores = unary[0] + start
    backpointers = []
    for t in range(1, num_step):
        # candidates[i, j] = best_scores[i] + transition[i, j] + unary[t, j]
        candidates = (tf.expand_dims(best_scores, 1)
                      + transition
                      + tf.expand_dims(unary[t], 0))
        # Axis 0 indexes the previous tag i: for each current tag j, remember
        # which predecessor gave the best score, then keep that score.
        backpointers.append(tf.argmax(candidates, axis=0))
        best_scores = tf.reduce_max(candidates, axis=0)

    # Backtrack (highest-scoring path only): start at the best final tag and
    # follow the stored predecessors back to step 0.
    tag = tf.argmax(best_scores, axis=0)
    path = [tag]
    for pointers in reversed(backpointers):
        tag = tf.gather(pointers, tag)
        path.append(tag)
    path.reverse()  # built last-to-first; return in time order
    return best_scores, path


def lstm_crf_viterbi(observation, transition, pi):
    """Viterbi-decode the highest-scoring tag sequence for each batch element.

    Scores combine additively: a path y_0..y_{T-1} for batch element b scores
    pi[b, y_0] + sum_t observation[b, t, y_t] + sum_{t>0} transition[y_{t-1}, y_t].

    Args:
        observation: Tensor [batch_size, num_step, output_dims]; per-step tag
            scores (the neural network output). The batch and step dimensions
            must be statically known.
        transition: Tensor [output_dims, output_dims]; transition matrix,
            transition[i, j] = score of moving from tag i to tag j.
        pi: Tensor [batch_size, output_dims]; initial tag scores.

    Returns:
        (previous, all_path_tag_sequence):
        previous -- list of batch_size tensors of shape [output_dims]; entry j
            is the best score of any path ending in tag j at the last step
            (the overall best score is its maximum).
        all_path_tag_sequence -- list of batch_size lists, each containing
            num_step scalar int64 tensors: the highest-scoring tag path.

    Raises:
        ValueError: if the sequence length is zero.
    """
    # int() works for both TF1 Dimension objects and TF2 plain ints.
    batch_size = int(observation.shape[0])
    num_step = int(observation.shape[1])
    if num_step < 1:
        raise ValueError("observation must contain at least one time step")

    previous = []
    all_path_tag_sequence = []
    for b in range(batch_size):
        final_scores, best_path = _viterbi_decode_one(
            observation[b], transition, pi[b], num_step)
        previous.append(final_scores)
        all_path_tag_sequence.append(best_path)
    return previous, all_path_tag_sequence
           

繼續閱讀