天天看點

TensorFlow softmax VS sparse softmax

sparse_softmax_cross_entropy_with_logits VS softmax_cross_entropy_with_logits

這兩者都是計算分類問題的softmax loss的,是以兩者的輸出應該是一樣的,唯一差別是兩者的labels輸入形似不一樣。

Difference

在tensorflow中使用softmax loss的時候,會發現有兩個softmax cross entropy。剛開始很難看出什麼差别,結合程式看的時候,就很容易能看出兩者差異。總的來說兩者都是計算分類問題的softmax交叉熵損失,而兩者使用的标簽真值的形式不同。

  • sparse_softmax_cross_entropy_with_logits:

    使用的是實數來表示類别,資料類型為int16,int32,或者 int64,标簽大小範圍為[0,num_classes-1],标簽的次元為[batch_size]大小。

  • softmax_cross_entropy_with_logits:

    使用的是one-hot二進制碼來表示類别,資料類型為float16,float32,或者float64,次元為[batch_size, num_classes]。這裡需要說明一下的時,标簽資料類型并不是Bool型的。這是因為實際上在tensorflow中,softmax_cross_entropy_with_logits中的每一個類别是一個機率分布,tensorflow中對該子產品的說明中明确指出了這一點,Each row labels[i] must be a valid probability distribution。很顯然,one-hot的二進碼也可以看是一個有效的機率分布。

另外stackoverflow上面對兩者的差別有一個總結說得很清楚,可以參考一下。

Common

有一點需要注意的是,softmax_cross_entropy_with_logits和sparse_softmax_cross_entropy_with_logits中的輸入都需要unscaled logits,因為tensorflow内部機制會将其進行歸一化操作以提高效率,什麼意思呢?就是說計算loss的時候,不要将輸出的類别值進行softmax歸一化操作,輸入就是 wT∗X+b 的結果。

tensorflow的說明是這樣的:

Warning: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.

至于為什麼這樣可以提高效率,簡單地說就是把unscaled digits輸入到softmax loss中在反向傳播計算倒數時計算量更少,感興趣的可以參考pluskid大神的部落格Softmax vs. Softmax-Loss: Numerical Stability,博文裡面講得非常清楚了。另外說一下,看了大神的博文,不得不說大神思考問題和解決問題的能力真的很強!

Example

import tensorflow as tf
#batch_size = 
labels = tf.constant([[0, 0, 0, 1],[0, 1, 0, 0]])
logits = tf.constant([[-3.4, 2.5, -1.2, 5.5],[-3.4, 2.5, -1.2, 5.5]])

loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
loss_s = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(labels,), logits=logits)

with tf.Session() as sess:  
    print "softmax loss:", sess.run(loss)
    print "sparse softmax loss:", sess.run(loss_s)
           

Output:

softmax loss: [ 0.04988896 3.04988885]

sparse softmax loss: [ 0.04988896 3.04988885]

Reference

tensorflow:softmax_cross_entropy_with_logits

tensorflow:sparse_softmax_cross_entropy_with_logits

stackoverflow

Softmax vs. Softmax-Loss: Numerical Stability

繼續閱讀