天天看點

為什麼softmax值不适合用于評估LSTM聊天機器人輸出置信度

        使用LSTM建構聊天機器人,無論輸入是什麼,就算輸入完全不符合文法,模型都會給出一個輸出,顯然,這個輸出不是我們想要的,如何識别模型輸出是不是我們想要的?我們需要一種評估名額,評估模型輸出的置信度。那能不能使用模型最後一層softmax值做為置信度名額?分析 LSTM網絡模型圖,可以知道,LSTM模型的本質是通過訓練給定的語料集,找到合适的權重值,建立了一種映射關系,把輸入映射到輸出。對于語料集中列出的對話,模型建立了合适的權重值,輸入能正确的映射到輸出,對于語料集中沒有的對話,模型沒有得到訓練,自然沒有建立合适的權重值,模型的輸出是随機的,最後一層softmax值自然也是随機的,是以最後一層softmax值不能做為LSTM模型的置信度名額。

為什麼softmax值不适合用于評估LSTM聊天機器人輸出置信度
為什麼softmax值不适合用于評估LSTM聊天機器人輸出置信度

       下面對“最後一層softmax值不能做為LSTM模型的置信度名額”進行實測驗證,選取129條對話的小語料集訓練LSTM模型,進行了150輪訓練,精準度達到了1.0。從下面的測試結果圖可以看到,使用訓練集中的對話測試的softmax值分别是:0.999,1.0,0.992,非訓練集中的對話結果是:0.769,0.999,1.0,結果符合之前的判斷,是随機的,最後一層softmax值不能做為LSTM模型的置信度名額。

##模型計算softmax值代碼:
        decode_pred = sess.run(self.decoder_outputs_decode, input_feed)
        pred = decode_pred.sample_id

        if self.time_major:
            pred = tf.transpose(pred, (1, 0))

        pred = sess.run(pred)
        decode_rnn_output = decode_pred.rnn_output
        if self.time_major:
            decode_rnn_output = tf.transpose(decode_rnn_output,(1, 0, 2))

        decode_softmax = tf.nn.softmax(decode_rnn_output,axis=2)
        decode_softmax_val = sess.run(decode_softmax)
        code_id = int(pred[0][0])

        print("confidence:", decode_softmax_val[0][0][code_id])

           

測試結果:

使用訓練語料集之内的對話進行測試:

為什麼softmax值不适合用于評估LSTM聊天機器人輸出置信度

使用訓練語料集之外的對話進行測試:

為什麼softmax值不适合用于評估LSTM聊天機器人輸出置信度

繼續閱讀