sparse_softmax_cross_entropy_with_logits VS softmax_cross_entropy_with_logits
這兩者都是計算分類問題的softmax loss的,是以兩者的輸出應該是一樣的,唯一差別是兩者的labels輸入形似不一樣。
Difference
在tensorflow中使用softmax loss的時候,會發現有兩個softmax cross entropy。剛開始很難看出什麼差别,結合程式看的時候,就很容易能看出兩者差異。總的來說兩者都是計算分類問題的softmax交叉熵損失,而兩者使用的标簽真值的形式不同。
-
sparse_softmax_cross_entropy_with_logits:
使用的是實數來表示類别,資料類型為int16,int32,或者 int64,标簽大小範圍為[0,num_classes-1],标簽的次元為[batch_size]大小。
-
softmax_cross_entropy_with_logits:
使用的是one-hot二進制碼來表示類别,資料類型為float16,float32,或者float64,次元為[batch_size, num_classes]。這裡需要說明一下的時,标簽資料類型并不是Bool型的。這是因為實際上在tensorflow中,softmax_cross_entropy_with_logits中的每一個類别是一個機率分布,tensorflow中對該子產品的說明中明确指出了這一點,Each row labels[i] must be a valid probability distribution。很顯然,one-hot的二進碼也可以看是一個有效的機率分布。
另外stackoverflow上面對兩者的差別有一個總結說得很清楚,可以參考一下。
Common
有一點需要注意的是,softmax_cross_entropy_with_logits和sparse_softmax_cross_entropy_with_logits中的輸入都需要unscaled logits,因為tensorflow内部機制會将其進行歸一化操作以提高效率,什麼意思呢?就是說計算loss的時候,不要将輸出的類别值進行softmax歸一化操作,輸入就是 wT∗X+b 的結果。
tensorflow的說明是這樣的:
Warning: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.
至于為什麼這樣可以提高效率,簡單地說就是把unscaled digits輸入到softmax loss中在反向傳播計算倒數時計算量更少,感興趣的可以參考pluskid大神的部落格Softmax vs. Softmax-Loss: Numerical Stability,博文裡面講得非常清楚了。另外說一下,看了大神的博文,不得不說大神思考問題和解決問題的能力真的很強!
Example
import tensorflow as tf
#batch_size =
labels = tf.constant([[0, 0, 0, 1],[0, 1, 0, 0]])
logits = tf.constant([[-3.4, 2.5, -1.2, 5.5],[-3.4, 2.5, -1.2, 5.5]])
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
loss_s = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(labels,), logits=logits)
with tf.Session() as sess:
print "softmax loss:", sess.run(loss)
print "sparse softmax loss:", sess.run(loss_s)
Output:
softmax loss: [ 0.04988896 3.04988885]
sparse softmax loss: [ 0.04988896 3.04988885]
Reference
tensorflow:softmax_cross_entropy_with_logits
tensorflow:sparse_softmax_cross_entropy_with_logits
stackoverflow
Softmax vs. Softmax-Loss: Numerical Stability