可見,使用 MSE 損失函數,受離群點的影響較大,雖然樣本中隻有 5 個離群點,但是拟合的直線還是比較偏向于離群點。
從上面可以看出,該函數實際上就是一個分段函數,在[-1,1]之間實際上就是L2損失,這樣解決了L1的不光滑問題,在[-1,1]區間外,實際上就是L1損失,這樣就解決了離群點梯度爆炸的問題
實作 (PyTorch)
def _smooth_l1_loss(input, target, reduction='none'):
# type: (Tensor, Tensor) -> Tensor
t = torch.abs(input - target)
ret = torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
if reduction != 'none':
ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
return ret
也可以添加個參數
beta
這樣就可以控制,什麼範圍的誤差使用MSE,什麼範圍内的誤差使用MAE了。
def smooth_l1_loss(input, target, beta=1. / 9, reduction = 'none'):
"""
very similar to the smooth_l1_loss from pytorch, but with
the extra beta parameter
"""
n = torch.abs(input - target)
cond = n < beta
ret = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
if reduction != 'none':
ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
return ret
總結
對于大多數CNN網絡,我們一般是使用L2-loss而不是L1-loss,因為L2-loss的收斂速度要比L1-loss要快得多。
對于邊框預測回歸問題,通常也可以選擇平方損失函數(L2損失),但L2範數的缺點是當存在離群點(outliers)的時候,這些點會占loss的主要組成部分。比如說真實值為1,預測10次,有一次預測值為1000,其餘次的預測值為1左右,顯然loss值主要由1000決定。是以FastRCNN采用稍微緩和一點絕對損失函數(smooth L1損失),它是随着誤差線性增長,而不是平方增長。
Smooth L1 和 L1 Loss 函數的差別在于,L1 Loss 在0點處導數不唯一,可能影響收斂。Smooth L1的解決辦法是在 0 點附近使用平方函數使得它更加平滑。
Smooth L1的優點
- 相比于L1損失函數,可以收斂得更快。
- 相比于L2損失函數,對離群點、異常值不敏感,梯度變化相對更小,訓練時不容易跑飛。
smooth L1 loss能從兩個方面限制梯度:
- 當預測框與 ground truth 差别過大時,梯度值不至于過大;
- 當預測框與 ground truth 差别很小時,梯度值足夠小。
考察如下幾種損失函數,其中
損失函數對 x 的導數分别為:
觀察 (4),當 x 增大時 L2 損失對 x 的導數也增大。這就導緻訓練初期,預測值與 groud truth 差異過于大時,損失函數對預測值的梯度十分大,訓練不穩定。
根據方程 (5),L1 對 x 的導數為常數。這就導緻訓練後期,預測值與 ground truth 差異很小時, L1 損失對預測值的導數的絕對值仍然為 1,而 learning rate 如果不變,損失函數将在穩定值附近波動,難以繼續收斂以達到更高精度。
最後觀察 (6),smooth L1 在 x 較小時,對 x 的梯度也會變小,而在 x 很大時,對 x 的梯度的絕對值達到上限 1,也不會太大以至于破壞網絡參數。 smooth L1 完美地避開了 L1 和 L2 損失的缺陷。其函數圖像如下:
由圖中可以看出,它在遠離坐标原點處,圖像和 L1 loss 很接近,而在坐标原點附近,轉折十分平滑,不像 L1 loss 有個尖角,是以叫做 smooth L1 loss。
參考文獻:
https://blog.csdn.net/weixin_41940752/article/details/93159710
https://www.cnblogs.com/wangguchangqing/p/12021638.html
https://www.jianshu.com/p/19483787fa24