
知識導圖
對于一個深度學習的訓練過程,可以将它描述為讓網絡輸出值和實際值越來越接近的過程。我們通過訓練優化器來完成這個過程,還需要一個評估函數來為我們的優化器指明方向。這個評估函數用來估量模型的預測值和真實值的不一緻程度,也就是所謂的損失函數。
Loss函數有很多,并且在很多的深度學習任務中,有時候是需要我們自行去根據任務相關來設計Loss函數的。
1. 回歸任務中的損失函數
1.1 MAE loss(L1)
L1 Loss 是一個衡量輸入x(模型預測輸出)和目标y之間差的絕對值的平均值,也叫MAE Loss。
由于L1 Loss 具有稀疏性,為了懲罰較大的值,是以常常将其作為正則項添加到其他Loss中作為限制。L1 Loss的最大問題是梯度在0點不平滑,導緻會跳過極小值。
在Pytorch中,L1 Loss的執行個體化類為:
class
其中N為一個batch的樣本數,參數reduction控制batch的loss取每個樣本L1 loss的均值還是總和,預設為mean。
1.2 MSE loss(L2)
L2 Loss是輸入x(模型預測輸出)和目标y之間均方誤差,是以也叫做MSE Loss:
同樣,L2 Loss也常常作為正則項。
當y和f(x)也就是真實值和預測值的內插補點大于1時,會放大誤差;而當內插補點小于1時,則會縮小誤差,這是平方運算決定的。MSE對于較大的誤差(>1)給予較大的懲罰,較小的誤差(<1)給予較小的懲罰。也就是說,對離群點比較敏感,受其影響較大。如果樣本中存在離群點,MSE會給離群點更高的權重,這就會犧牲其他正常點資料的預測效果,最終降低整體的模型性能。
在Pytorch中,L2 Loss的執行個體化類為:
class
同樣,N為一個batch的樣本數,參數reduction控制batch的loss取每個樣本L2 loss的均值還是總和,預設為mean。
1.3 選MSE還是MAE?
L1 Loss作為損失函數更穩定,并且對離群值不敏感,而且 L1 Loss 在0處不可導,大部分情況下梯度都是相等的,這意味着即使對于小的損失值,其梯度也是大的。這不利于函數的收斂和模型的學習。另外,在深度學習中,收斂較慢。L2 Loss導數求解速度高,但是其對離群值敏感,不過可以将離群值的導數設為0(導數值大于某個門檻值)來避免這種情況。
在實際的應用中,這兩種損失函數的選擇要視情況而定:從計算機求解梯度的複雜度來說,MSE 要優于 MAE,而且梯度也是動态變化的,能較快準确達到收斂。但是從離群點角度來看,如果離群點是實際資料或重要資料,而且是應該被檢測到的異常值,那麼我們應該使用MSE。另一方面,離群點僅僅代表資料損壞或者錯誤采樣,無須給予過多關注,那麼我們應該選擇MAE作為損失。
1.4 Huber loss 和 Smooth L1 loss
Huber loss結合了MSE和MAE,定義如下:
Huber Loss 包含了一個超參數 δ。δ 值的大小決定了 Huber Loss 對 MSE 和 MAE 的側重性,當 |y−f(x)| ≤ δ 時,變為 MSE;當 |y−f(x)| > δ 時,則變成類似于 MAE,是以 Huber Loss 同時具備了 MSE 和 MAE 的優點,減小了對離群點的敏感度問題,實作了處處可導的功能。
Smooth L1 loss就是Huber loss的參數δ取值為1時的形式。在Faster R-CNN以及SSD中對邊框的回歸使用的損失函數都是Smooth L1 loss。
Smooth L1 Loss 能從兩個方面限制梯度:
1.當預測框與 ground truth 差别過大時,梯度值不至于過大;
2.當預測框與 ground truth 差别很小時,梯度值足夠小
從上面可以看出,Smooth L1 loss函數實際上就是一個分段函數,在[-1,1]之間實際上就是L2損失,這樣解決了L1的不光滑問題,在[-1,1]區間外,實際上就是L1損失,這樣就解決了離群點梯度爆炸的問題。
在Pytorch中,Smooth L1 Loss的執行個體化類為:
class
2. 分類任務中的損失函數
2.1 交叉熵損失
2.1.1 什麼是交叉熵損失?(舉例)
在一個多分類任務中,交叉熵損失函數是非常常見的,其定義如下:
其中:
- [M] ——類别的數量;
- [y_c] ——訓示變量(0或1),如果該類别和樣本的類别相同就是1,否則是0;
- [p_c] ——對于觀測樣本屬于類别 [c] 的預測機率。
交叉熵,實際上就是真實标簽和預測标簽兩個分布的交叉熵。舉個例子:
假設一個5分類問題,然後一個樣本I的标簽y_c=[0,0,0,1,0],也就是說樣本I的真實标簽是4:
- 假設模型預測的結果機率p_c=[0.1,0.15,0.05,0.6,0.1],可以看出這個預測是對的,也就是類别4,那麼對應的損失值為L=-log(0.6)。
- 假設p_c=[0.15,0.2,0.4,0.1,0.15],這個預測結果就很離譜了,因為真實标簽是4,而你覺得這個樣本是4的機率隻有0.1(遠不如其他機率高,如果是在測試階段,那麼模型就會預測該樣本屬于類别3),對應損失值L=-log(0.1)。
- 假設p_c=[0.05,0.15,0.4,0.3,0.1],這個預測結果雖然也錯了,但是沒有前面那個那麼離譜,對應的損失L=-log(0.3)。
根據log函數的性質,有-log(0.6) < -log(0.3) < -log(0.1)。可以看出預測錯比預測對的損失要大,預測錯得離譜比預測錯得輕微的損失要大。
2.1.2 softmax loss
對于網絡層中常用的softmax loss,其實,在交叉熵損失的公式裡面,如果預測機率p_c是由softmax函數(softmax函數輸出向量為樣本在N個類别中,屬于每個類别的機率)得到的。那麼此時的softmax loss就是交叉熵loss。
2.1.3 Pytorch中的二分類交叉熵損失
在Pytorch中,交叉熵 Loss有幾個函數,其中,二分類的交叉熵為:
1.
對于BCELoss,由于二分類樣本的輸出隻有兩維,是以有:
其中參數reduction表示一個batch樣本loss的統計方式,預設為均值統計。API提供權重參數weight來調整loss值,weight是和分類次元一樣的tensor,一般weight預設即可。 BCEWithLogitsLoss相當于在BCELoss的基礎上加了sigmoid層:
這樣做的好處是可以使用一個tricks:log_sum_exp ,使得數值結果更加穩定,實際任務時,二分類交叉熵損失建議使用BCEWithLogitsLoss。
2.1.4 Pytorch中的多分類交叉熵損失
多分類任務的交叉熵loss為:
class
Pytorch的CrossEntropyLoss實際上做了這麼幾件事情:
1.計算了一層softmax:softmax函數會傳回樣本分類成每一個類别的機率分數,值在0~1之間。
2.将Softmax之後的結果取log,将乘法改成加法減少計算量,同時保障函數的單調性 . 3.上面的輸出與Label對應的那個值拿出來,乘以權重weight(用于資料樣本分布不均衡的調整),去掉負号,再求均值(reduction預設為mean)。Pytorch中也提供了兩個函數:
1.
而nn.CrossEntropyLoss的作用就相當于nn.LogSoftmax + nn.NLLLoss。 - nn.LogSoftmax完成上面的步驟1-2:
- nn.NLLLoss完成上面的步驟3(取出label對應的值):
這裡一個需要注意的點是nn.CrossEntropyLoss已經做了一次softmax,是以它的input在之前不需要再在網絡中添加一個softmax層了。
2.2 鉸鍊損失(Hinge loss)
鉸鍊損失的出名應用是作為SVM的損失函數,其名字來自于Hinge loss的圖像:
其中,$hat{y}$是預測值,為一機率分數,y是标簽值。與0-1損失相比,Hinge loss的圖像如下:
同樣對于多分類問題,Pytorch提供如下函數表示多分類hinge loss:
class
其中次數p一般預設為1。weight為根據樣本類别分布而設定的權重,可選擇性設定。margin為hinge的門檻值,就像圖像表示的函數,1也是margin值。x[i]為該樣本錯誤預測的得分,x[y]為正确預測的得分。兩者的內插補點可用來表示兩種預測結果的相似關系,margin是一個由自己指定的安全系數。我們希望正确預測的得分高于錯誤預測的得分,且高出一個邊界值 margin,換句話說,x[y]越高越好,x[i]越低越好,(x[y]–x[i])越大越好,(x[i]–x[y])越小越好,但二者得分之差最多為margin就足夠了,差距更大并不會有任何獎勵。這樣設計的目的在于,對單個樣本正确分類隻要有margin的把握就足夠了,更大的把握則不必要,過分注重單個樣本的分類效果反而有可能使整體的分類效果變壞。分類器應該更加專注于整體的分類誤差。
2.3 KL散度
KL散度也被稱為相對熵,常被用于生成模型,比如GAN。在資訊論中,關于熵有如下表述:
- 熵:可以表示一個事件P包含多少資訊。
- KL散度:可以表述事件P和事件P的拟合事件Q有多大不同
- 交叉熵:可以表述從事件P的角度如何去描述P的拟合事件Q。
前面說到的交叉熵,便是表達了預測事件和真實事件的相關程度,同樣,KL散度也同樣能描述兩個時間分布的關系,并作為損失函數使用。
上面公式是描述連續型事件分布的KL散度公式,不難發現,第一項便是之前說到的交叉熵的連續型,而後一項則是熵本身的定義,反映了事件P的資訊量大小,是以,對于真實事件P和預測事件Q,熵,相對熵(KL散度),交叉熵有如下關系:
P與Q的交叉熵 = P與Q的KL散度 - P的熵Pytorch中提供KLDivLoss函數來表述離散型KL散度損失:
class
對于一個N個樣本的batch,KL散度損失做如下計算:
參數reduction控制batch的loss取每個樣本loss的均值還是總和,預設為mean。
2.4 Triplet loss
Triplet loss用于訓練差異性較小的樣本,最初出現在FaceNet的論文中: FaceNet: A Unified Embedding for Face Recognition and Clustering ,可以學到較好的人臉的embedding。
Triplet loss的輸入是一個三元組:(anchor,positive, negative),其中,從訓練資料集中随機選一個樣本,該樣本稱為anchor,然後随機選取和anchor同類的樣本positive和不同類的樣本negative。下圖是人臉embedding産生的Triplet loss:
訓練模型使得Triplet loss最小就是拉近同類(anchor,positive)距離,拉遠異類(anchor,negative)距離,如下圖:
Triplet loss的公式如下:
在訓練的時候會得到很多的三元組(a,p,n),他們可以分為以下幾類:
- easy triplets :loss = 0,d(a, p) + margin < d(a, n),ap對的距離遠遠小于an對的距離。即,類内距離很小,類間很大距離,這種情況不需要優化。
- hard triplets :d(a, n) < d(a, p) ,ap對的距離大于于an對的距離,即類内距離大于類間距離。這種情況比較難優化。
- semi-hard triplets :d(a, p) < d(a, n) < d(a, p) + margin。ap對的距離和an對的距離比較高近。即,和很近,但都在一個margin内。比較容易優化。
一般在訓練的時候是随機選取semi-hard triplets 進行訓練的,但早期為了網絡loss平穩,一般選擇easy triplets進行優化,後期為了優化訓練關鍵是要選擇hard triplets,他們是活躍的,是以可以幫助改進模型。
Triplet loss有兩種訓練方法,
- OFFLINE : 将訓練集所有資料經過計算得到對應的 embeddings, 然後再計算 triplet loss,這種方式的效率不高,因為要周遊所有的資料得到三元組。
- ONLINE : 在ReID的論文:In Defense of the Triplet Loss for Person Re-Identification中使用了這樣的方式。在訓練時,分為Batch All和Batch Hard。Batch All計算了一個batch中所有val的的hard triplet 和 semi-hard triplet, 然後取平均得到Triplet loss。而Batch Hard則是對于每一個anchor,都選擇距離最大的d(a, p) 和距離最大的d(a, n)。 論文中選擇Batchhard,随機抽取P個人,每個人K張圖檔形成一個batch,每個人的K張圖檔之間形成K*(K-1)個ap對,再在剩下其他人裡取一個與該ap距離最近的negative,組成apn組并将apn組按照下面式子中的公式取模型裡進行訓練,使得下面的式子最小化。
Pytorch中提供TripletMarginLoss函數來實作Triplet loss,其中p為距離範數,預設為2,即2-範數:
class
3. PyTorch 如何自定義損失函數
關于PyTorch 如何自定義損失函數?總的來說,大體有以下方法:
3.1 調用torch.Tensor的原生接口
和一般的自定義函數一樣隻需要在
init()裡面定義好超參數,再在forward裡寫好計算過程就可以了。因為繼承了nn.Module,是以這個loss類在執行個體化之後可以直接運作
call()方法。 這裡以center loss為例(center loss來自于ECCV 2016 的一篇論文,被使用在ReID任務中,論文位址)
import
3.2 Pytorch使用numpy/scipy擴充
原生接口提供了torch.nn.functional子產品來代替一些函數操作,當該子產品功能不能滿足自定義函數的功能實作要求時,我們可以先将tensor轉換為numpy,再使用numpy/scipy來實作函數功能,最後再傳回tensor。下面是Pytorch官網給出的使用numpy/scipy擴充自定義快速傅裡葉變換的案例:
import
參考:
[1]:https://www.cnblogs.com/wangguchangqing/p/12021638.html
[2]:https://blog.csdn.net/wonengguwozai/article/details/74066157
[3]:https://msd.misuland.com/pd/2884250171976192486
[4]:https://mp.weixin.qq.com/s/Xbi5iOh3xoBIK5kVmqbKYA
[5]:https://blog.csdn.net/weixin_40671425/article/details/98068190
[6]:https://blog.csdn.net/weixin_45191152/article/details/97762005