天天看點

一文弄懂LogSumExp技巧

引言

今天來學習下LogSumExp(LSE)​​1​​​技巧,主要解決計算Softmax或CrossEntropy​​2​​時出現的上溢(overflow)或下溢(underflow)問題。

我們知道程式設計語言中的數值都有一個表示範圍的,如果數值過大,超過最大的範圍,就是上溢;如果過小,超過最小的範圍,就是下溢。

什麼是LSE

LSE被定義為參數指數之和的對數:

輸入可以看成是一個n維的向量,輸出是一個标量。

為什麼需要LSE

在機器學習中,計算機率輸出基本都需要經過Softmax函數,它的公式應該很熟悉了吧

但是Softmax存在上溢和下溢大問題。如果太大,對應的指數函數也非常大,此時很容易就溢出,得到​​

​nan​

​​結果;如果太小,或者說負的太多,就會導緻出現下溢而變成0,如果分母變成0,就會出現除0的結果。

此時我們經常看到一個常見的做法是(其實用到的是指數歸一化技巧, exp-normalize​​3​​​),先計算中的最大值,然後根據

這種轉換是等價的,經過這一變換,就避免了上溢,最大值變成了;同時分母中也會有一個1,就避免了下溢。

我們通過執行個體來了解一下。

def bad_softmax(x):
  y = np.exp(x)
  return y / y.sum()
 
x = np.array([1, -10, 1000])
print(bad_softmax(x))      
... RuntimeWarning: overflow encountered in exp
... RuntimeWarning: invalid value encountered in true_divide
array([ 0.,  0., nan])      

接下來進行上面的優化,并進行測試:

def softmax(x):
  b = x.max()
  y = np.exp(x - b)
  return y / y.sum()
 
print(softmax(x))      
array([0., 0., 1.])      

我們再看下是否會出現下溢:

x = np.array([-800, -1000, -1000])
print(bad_softmax(x))
# array([nan, nan, nan])
print(softmax(x))
# array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])      

嗯,看來解決了這個兩個問題。

一文弄懂LogSumExp技巧

等等,不是說LSE嗎,怎麼整了個什麼歸一化技巧。

好吧,回到LSE。

我們對Softmax取對數,得到:

因為上面最後一項也有上溢的問題,是以應用同樣的技巧,得

同樣是取中的最大值。

這樣,我們就得到了LSE的最終表示:

此時,Softmax也可以這樣表示:

對LogSumExp求導就得到了exp-normalize(Softmax)的形式,

那我們是使用exp-normalize還是使用LogSumExp呢?

如果你需要保留Log空間,那麼就計算,此時使用LogSumExp技巧;如果你隻需要計算Softmax,那麼就使用exp-normalize技巧。

怎麼實作LSE

實作LSE就很簡單了,我們通過代碼實作一下。

def logsumexp(x):
  b = x.max()
  return b + np.log(np.sum(np.exp(x - b)))
 
def softmax_lse(x):
  return np.exp(x - logsumexp(x))      

上面是基于LSE實作了Softmax,下面測試一下:

> x1 = np.array([1, -10, 1000])
> x2 = np.array([-900, -1000, -1000])
> softmax_lse(x1)
array([0., 0., 1.])
> softmax(x1)
array([0., 0., 1.])
> softmax_lse(x2)
array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])
> softmax(x2)
> array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])      

最後我們看一下數值穩定版的Sigmoid函數

數值穩定的Sigmoid函數

我們知道Sigmoid函數公式為:

對應的圖像如下:

一文弄懂LogSumExp技巧

其中包含一個,我們看一下的圖像:

一文弄懂LogSumExp技巧

從上圖可以看出,如果很大,會非常大,而很小就沒事,變成無限接近。

當Sigmoid函數中的負的特别多,那麼就會變成,就出現了上溢;

那麼如何解決這個問題呢?可以表示成兩種形式:

當時,我們根據的圖像,我們取的形式;

# 原來的做法
def sigmoid_naive(x):
  return 1 / (1 + math.exp(-x))
  
# 優化後的做法
def sigmoid(x):
  if x < 0:
    return math.exp(x) / (1 + math.exp(x))
  else:
    return 1 / (1 + math.exp(-x))      
> sigmoid_naive(2000)
1.0
> sigmoid(2000)
1.0
> sigmoid_naive(-2000)
OverflowError: math range error
> sigmoid(-2000)
0.0      

References

  1. ​​The Log-Sum-Exp Trick ​​​ ​​↩︎​​
  2. ​​一文弄懂交叉熵損失 ​​​ ​​↩︎​​
  3. ​​Exp-normalize trick​​​ ​​↩︎​​

繼續閱讀