天天看點

交叉熵損失函數原理詳解交叉熵損失函數原理詳解

交叉熵損失函數原理詳解

之前在代碼中經常看見交叉熵損失函數(CrossEntropy Loss),隻知道它是分類問題中經常使用的一種損失函數,對于其内部的原理總是模模糊糊,而且一般使用交叉熵作為損失函數時,在模型的輸出層總會接一個softmax函數,至于為什麼要怎麼做也是不懂,是以專門花了一些時間打算從原理入手,搞懂它,故在此寫一篇部落格進行總結,以便以後翻閱。

交叉熵簡介

交叉熵是資訊論中的一個重要概念,主要用于度量兩個機率分布間的差異性,要了解交叉熵,需要先了解下面幾個概念。

資訊量

資訊創始者香農(Shannon)認為“資訊是用來消除随機不确定性的東西”,也就是說衡量資訊量的大小就是看這個資訊消除不确定性的程度。

“太陽從東邊升起”,這條資訊并沒有減少不确定性,因為太陽肯定是從東邊升起的,這是一句廢話,資訊量為0。

”2018年中國隊成功進入世界杯“,從直覺上來看,這句話具有很大的資訊量。因為中國隊進入世界杯的不确定性因素很大,而這句話消除了進入世界杯的不确定性,是以按照定義,這句話的資訊量很大。

根據上述可總結如下:資訊量的大小與資訊發生的機率成反比。機率越大,資訊量越小。機率越小,資訊量越大。

設某一事件發生的機率為P(x),其資訊量表示為:

I ( x ) = − log ⁡ ( P ( x ) ) I\left ( x \right ) = -\log\left ( P\left ( x \right ) \right ) I(x)=−log(P(x))

其中 I ( x ) I\left ( x \right ) I(x)表示資訊量,這裡 log ⁡ \log log表示以e為底的自然對數。

資訊熵

資訊熵也被稱為熵,用來表示所有資訊量的期望。

期望是試驗中每次可能結果的機率乘以其結果的總和。

是以資訊量的熵可表示為:(這裡的 X X X是一個離散型随機變量)

H ( X ) = − ∑ i = 1 n P ( x i ) log ⁡ ( P ( x i ) ) ) ( X = x 1 , x 2 , x 3 . . . , x n ) H\left ( \mathbf{X} \right ) = -\sum \limits_{i=1}^n P(x_{i}) \log \left ( P \left ( x_{i} \right ))) \qquad ( \mathbf{X}= x_{1},x_{2},x_{3}...,x_{n} \right) H(X)=−i=1∑n​P(xi​)log(P(xi​)))(X=x1​,x2​,x3​...,xn​)

使用明天的天氣機率來計算其資訊熵:

序号 事件 機率P 資訊量
1 明天是晴天 0.5 − log ⁡ ( 0.5 ) -\log \left ( 0.5 \right ) −log(0.5)
2 明天出雨天 0.2 − log ⁡ ( 0.2 ) -\log \left ( 0.2 \right ) −log(0.2)
3 多雲 0.3 − log ⁡ ( 0.3 ) -\log \left ( 0.3 \right ) −log(0.3)

H ( X ) = − ( 0.5 ∗ log ⁡ ( 0.5 ) + 0.2 ∗ log ⁡ ( 0.2 ) + 0.3 ∗ log ⁡ ( 0.3 ) ) H\left ( \mathbf{X} \right ) = -\left ( 0.5 * \log \left ( 0.5 \right ) + 0.2 * \log \left ( 0.2 \right ) + 0.3 * \log \left ( 0.3 \right ) \right) H(X)=−(0.5∗log(0.5)+0.2∗log(0.2)+0.3∗log(0.3))

對于0-1分布的問題,由于其結果隻用兩種情況,是或不是,設某一件事情發生的機率為 P ( x ) P\left ( x \right ) P(x),則另一件事情發生的機率為 1 − P ( x ) 1-P\left ( x \right ) 1−P(x),是以對于0-1分布的問題,計算熵的公式可以簡化如下:

H ( X ) = − ∑ n = 1 n P ( x i log ⁡ ( P ( x i ) ) ) = − [ P ( x ) log ⁡ ( P ( x ) ) + ( 1 − P ( x ) ) log ⁡ ( 1 − P ( x ) ) ] = − P ( x ) log ⁡ ( P ( x ) ) − ( 1 − P ( x ) ) log ⁡ ( 1 − P ( x ) ) H\left ( \mathbf{X} \right ) = -\sum \limits_{n=1}^n P(x_{i}\log \left ( P \left ( x_{i} \right )) \right) \\ = -\left [ P\left ( x \right) \log \left ( P\left ( x \right ) \right ) + \left ( 1 - P\left ( x \right ) \right) \log \left ( 1-P\left ( x \right ) \right ) \right] \\ = -P\left ( x \right) \log \left ( P\left ( x \right ) \right ) - \left ( 1 - P\left ( x \right ) \right) \log \left ( 1-P\left ( x \right ) \right) H(X)=−n=1∑n​P(xi​log(P(xi​)))=−[P(x)log(P(x))+(1−P(x))log(1−P(x))]=−P(x)log(P(x))−(1−P(x))log(1−P(x))

相對熵(KL散度)

如果對于同一個随機變量 X X X有兩個單獨的機率分布 P ( x ) P\left(x\right) P(x)和 Q ( x ) Q\left(x\right) Q(x),則我們可以使用KL散度來衡量這兩個機率分布之間的差異。

下面直接列出公式,再舉例子加以說明。

D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) D_{KL}\left ( p || q \right) = \sum \limits_{i=1}^n p\left ( x_{i}\right ) \log \left ( \frac{p\left ( x_{i} \right )}{q\left ( x_{i} \right )} \right ) DKL​(p∣∣q)=i=1∑n​p(xi​)log(q(xi​)p(xi​)​)

在機器學習中,常常使用 P ( x ) P\left(x\right) P(x)來表示樣本的真實分布, Q ( x ) Q \left(x\right) Q(x)來表示模型所預測的分布,比如在一個三分類任務中(例如,貓狗馬分類器), x 1 , x 2 , x 3 x_{1}, x_{2}, x_{3} x1​,x2​,x3​分别代表貓,狗,馬,例如一張貓的圖檔真實分布 P ( X ) = [ 1 , 0 , 0 ] P\left(X\right) = [1, 0, 0] P(X)=[1,0,0], 預測分布 Q ( X ) = [ 0.7 , 0.2 , 0.1 ] Q\left(X\right) = [0.7, 0.2, 0.1] Q(X)=[0.7,0.2,0.1],計算KL散度:

D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) = p ( x 1 ) log ⁡ ( p ( x 1 ) q ( x 1 ) ) + p ( x 2 ) log ⁡ ( p ( x 2 ) q ( x 2 ) ) + p ( x 3 ) log ⁡ ( p ( x 3 ) q ( x 3 ) ) = 1 ∗ log ⁡ ( 1 0.7 ) = 0.36 D_{KL}\left ( p || q \right) = \sum \limits_{i=1}^n p\left ( x_{i}\right ) \log \left ( \frac{p\left ( x_{i} \right )}{q\left ( x_{i} \right )} \right ) \\ = p\left ( x_{1}\right ) \log \left ( \frac{p\left ( x_{1} \right )}{q\left ( x_{1} \right )} \right ) + p\left ( x_{2}\right ) \log \left ( \frac{p\left ( x_{2} \right )}{q\left ( x_{2} \right )} \right ) + p\left ( x_{3}\right ) \log \left ( \frac{p\left ( x_{3} \right )}{q\left ( x_{3} \right )} \right ) \\ = 1 * \log \left ( \frac{1}{0.7} \right ) = 0.36 DKL​(p∣∣q)=i=1∑n​p(xi​)log(q(xi​)p(xi​)​)=p(x1​)log(q(x1​)p(x1​)​)+p(x2​)log(q(x2​)p(x2​)​)+p(x3​)log(q(x3​)p(x3​)​)=1∗log(0.71​)=0.36

KL散度越小,表示 P ( x ) P\left(x\right) P(x)與 Q ( x ) Q\left(x\right) Q(x)的分布更加接近,可以通過反複訓練 Q ( x ) Q\left(x \right) Q(x)來使 Q ( x ) Q\left(x \right) Q(x)的分布逼近 P ( x ) P\left(x \right) P(x)。

交叉熵

首先将KL散度公式拆開:

D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) ) − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) = − H ( p ( x ) ) + [ − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) ] D_{KL}\left ( p || q \right) = \sum \limits_{i=1}^n p\left ( x_{i}\right ) \log \left ( \frac{p\left ( x_{i} \right )}{q\left ( x_{i} \right )} \right ) \\ = \sum \limits_{i=1}^n p \left (x_{i}\right) log \left(p \left (x_{i}\right)\right) - \sum \limits_{i=1}^n p \left (x_{i}\right) log \left(q \left (x_{i}\right)\right) \\ = -H \left (p \left(x \right) \right) + \left [-\sum \limits_{i=1}^n p \left (x_{i}\right) log \left(q \left (x_{i}\right)\right) \right] DKL​(p∣∣q)=i=1∑n​p(xi​)log(q(xi​)p(xi​)​)=i=1∑n​p(xi​)log(p(xi​))−i=1∑n​p(xi​)log(q(xi​))=−H(p(x))+[−i=1∑n​p(xi​)log(q(xi​))]

前者 H ( p ( x ) ) H \left (p \left (x \right)\right) H(p(x))表示資訊熵,後者即為交叉熵,KL散度 = 交叉熵 - 資訊熵

交叉熵公式表示為:

H ( p , q ) = − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H \left (p, q\right) = -\sum \limits_{i=1}^n p \left (x_{i}\right) log \left(q \left (x_{i}\right)\right) H(p,q)=−i=1∑n​p(xi​)log(q(xi​))

在機器學習訓練網絡時,輸入資料與标簽常常已經确定,那麼真實機率分布 P ( x ) P\left(x \right) P(x)也就确定下來了,是以資訊熵在這裡就是一個常量。由于KL散度的值表示真實機率分布 P ( x ) P\left(x\right) P(x)與預測機率分布 Q ( x ) Q \left(x\right) Q(x)之間的差異,值越小表示預測的結果越好,是以需要最小化KL散度,而交叉熵等于KL散度加上一個常量(資訊熵),且公式相比KL散度更加容易計算,是以在機器學習中常常使用交叉熵損失函數來計算loss就行了。

交叉熵在單分類問題中的應用

線上性回歸問題中,常常使用MSE(Mean Squared Error)作為loss函數,而在分類問題中常常使用交叉熵作為loss函數。

下面通過一個例子來說明如何計算交叉熵損失值。

假設我們輸入一張狗的圖檔,标簽與預測值如下:

*
Label 1
Pred 0.2 0.7 0.1

那麼loss

l o s s = − ( 0 ∗ log ⁡ ( 0.2 ) + 1 ∗ log ⁡ ( 0.7 ) + 0 ∗ log ⁡ ( 0.1 ) ) = 0.36 loss = -\left ( 0 * \log \left ( 0.2 \right ) + 1 * \log \left ( 0.7 \right ) + 0 * \log \left ( 0.1 \right )\right) = 0.36 loss=−(0∗log(0.2)+1∗log(0.7)+0∗log(0.1))=0.36

一個batch的loss為

l o s s = − 1 m ∑ i = 1 m ∑ j = 1 n p ( x i j ) l o g ( q ( x i j ) ) loss = -\frac{1}{m}\sum \limits_{i=1}^m \sum \limits_{j=1}^n p \left (x_{ij}\right) log \left(q \left (x_{ij}\right)\right) loss=−m1​i=1∑m​j=1∑n​p(xij​)log(q(xij​))

其中m表示樣本個數。

總結:

  • 交叉熵能夠衡量同一個随機變量中的兩個不同機率分布的差異程度,在機器學習中就表示為真實機率分布與預測機率分布之間的差異。交叉熵的值越小,模型預測效果就越好。
  • 交叉熵在分類問題中常常與softmax是标配,softmax将輸出的結果進行處理,使其多個分類的預測值和為1,再通過交叉熵來計算損失。

參考:

https://blog.csdn.net/tsyccnh/article/details/79163834

THE END

交叉熵損失函數原理詳解交叉熵損失函數原理詳解

繼續閱讀