天天看點

HMM 隐馬爾科夫鍊 viterbi算法 Pytorch實作

  • 基于pytorch實作了維特比算法,用于求解隐馬爾科夫鍊的預測問題,由于重點在于了解算法并能用code表達出來,是以并未對于狀态轉移矩陣、發射機率矩陣和初始狀态機率進行估計。實際上A,B,Pi的估計比起viterbi算法本身難度要低很多,直接基于訓練資料,用類似矩估計的思想去統計頻率直接得到A,B,Pi即可。
  • 代碼中添加了很詳細的注釋,新人第一次發部落格,希望能夠幫助像我一樣的小白了解viterbi算法,畢竟無論是HMM還是CRF,viterbi都是其中的靈魂,而弄清楚viterbi算法也能夠讓自己的動态規劃思想更上一層樓(DP大神除外),以下是代碼部分
import torch

# 維特比算法解決HMM預測問題,給定狀态轉移矩陣A,發射機率矩陣B,初始狀态機率Pi
def viterbi(self, word_list, word2id, state2id, A, B, Pi):
    # 初始化viterbi矩陣,shape = [狀态數, 序列長度],viterbi[i, j]表示當序列的第j個元素的狀态為i時,前j個元素的觀測鍊機率的最大值
    # 初始化backpointer矩陣, shape = [狀态數, 序列長度],backpointer[i, j]表示當序列的第j個元素的狀态為i時,使得前j個元素的觀測鍊機率能夠達到最大值的第j-1個元素的狀态
    A, B, Pi = torch.log(A), torch.log(B), torch.log(Pi)
    N, seq_len = len(state2id), len(word_list)
    viterbi = torch.zeros(N, seq_len)
    backpointer = torch.zeros(N, seq_len)

    B_t = B.t()  # shape=[M, N],B[word_id]表示目前觀測為word_id時各狀态的機率
    start_word_id = word2id.get(word_list[0], None)
    if start_word_id is None:
        # 如果目前觀測不在詞表中,則假設其發射機率服從均勻分布
        b_t = torch.log(torch.ones(N) / N)
    else:
        b_t = B_t[start_word_id]

    viterbi[:, 0] = Pi + b_t  # 第一個元素的觀測機率 = 初始狀态機率 * 發射機率
    backpointer[:, 0] = -1  # start_word之前的元素并不存在狀态,故取-1

    for step in range(1, seq_len):
        word_id = word2id.get(word_list[step], None)
        if word_id is None:
            # 如果目前觀測不在詞表中,則假設其發射機率服從均勻分布
            b_t = torch.log(torch.ones(N) / N)
        else:
            b_t = B_t[word_id]
        # 計算第step個觀測元素的狀态為state時,前step個元素的觀測鍊機率的最大值,以及觀測鍊機率最大時,第step-1個元素的狀态
        for state_id in range(len(state2id)):
            # 前step步的最優路徑必定包含前step-1步的最優路徑,隻需要乘上狀态轉移機率和發射機率,然後求max即可
            # (由于發射機率對于特定的觀測與狀态是相同的,是以也可以先求max,再乘上發射機率)
            max_prob, best_state_id = torch.max(viterbi[:, step - 1] + A[:, state_id], dim=0)
            viterbi[state_id, step] = max_prob + b_t[state_id]
            backpointer[state_id, step] = best_state_id

    # 終止,并且從最後一個元素開始回溯
    max_prob, best_state_id = torch.max(viterbi[:, seq_len - 1], dim=0)
    # 反向儲存最優路徑
    best_path = [best_state_id.item()]
    for step in range(seq_len - 1, 0, -1):
        # backpointer[i, j]中存儲了第j-1到j個元素的最優路徑,即第j-1個元素的最優狀态
        best_state_id = backpointer[best_state_id, step]
        best_path.append(best_state_id.item())

    # 将state_id組成的逆序序列翻轉,并轉化為state
    assert len(best_path) == len(word_list)
    id2state = dict((id, state) for state, id in state2id.items())
    best_state_path = [id2state[id] for id in reversed(best_path)]

    return best_state_path, max_prob
           

繼續閱讀