使用LSTM建構聊天機器人,無論輸入是什麼,就算輸入完全不符合文法,模型都會給出一個輸出,顯然,這個輸出不是我們想要的,如何識别模型輸出是不是我們想要的?我們需要一種評估名額,評估模型輸出的置信度。那能不能使用模型最後一層softmax值做為置信度名額?分析 LSTM網絡模型圖,可以知道,LSTM模型的本質是通過訓練給定的語料集,找到合适的權重值,建立了一種映射關系,把輸入映射到輸出。對于語料集中列出的對話,模型建立了合适的權重值,輸入能正确的映射到輸出,對于語料集中沒有的對話,模型沒有得到訓練,自然沒有建立合适的權重值,模型的輸出是随機的,最後一層softmax值自然也是随機的,是以最後一層softmax值不能做為LSTM模型的置信度名額。
下面對“最後一層softmax值不能做為LSTM模型的置信度名額”進行實測驗證,選取129條對話的小語料集訓練LSTM模型,進行了150輪訓練,精準度達到了1.0。從下面的測試結果圖可以看到,使用訓練集中的對話測試的softmax值分别是:0.999,1.0,0.992,非訓練集中的對話結果是:0.769,0.999,1.0,結果符合之前的判斷,是随機的,最後一層softmax值不能做為LSTM模型的置信度名額。
##模型計算softmax值代碼:
decode_pred = sess.run(self.decoder_outputs_decode, input_feed)
pred = decode_pred.sample_id
if self.time_major:
pred = tf.transpose(pred, (1, 0))
pred = sess.run(pred)
decode_rnn_output = decode_pred.rnn_output
if self.time_major:
decode_rnn_output = tf.transpose(decode_rnn_output,(1, 0, 2))
decode_softmax = tf.nn.softmax(decode_rnn_output,axis=2)
decode_softmax_val = sess.run(decode_softmax)
code_id = int(pred[0][0])
print("confidence:", decode_softmax_val[0][0][code_id])
測試結果:
使用訓練語料集之内的對話進行測試:
使用訓練語料集之外的對話進行測試: