天天看點

Tensorflow 損失函數(loss function)及自定義損失函數(一)

我主要分三篇文章給大家介紹tensorflow的損失函數,本篇為tensorflow内置的四個損失函數

(一)tensorflow内置的四個損失函數

(二)其他損失函數

(三)自定義損失函數

損失函數(loss function),量化了分類器輸出的結果(預測值)和我們期望的結果(标簽)之間的差距,這和分類器結構本身同樣重要。有很多的學者花費心思探讨如何改良損失函數使得分類器的結果最優,是以學會如何使用Tensorflow中的損失函數以及學會自己建構損失函數是非常重要的。

首先我們先規定一些變量這樣友善大家之後更加清楚的讀懂如何使用損失函數。

  1. Labels :标簽,在分類或者分割等問題中的标準答案。可以是1,2,3,4,5,6 。
  2. Labels_onehot : Onehot形式的标簽,即如果有3類那麼第一類表示為[1,0,0],第二類為[0,1,0],第三類為[0,0,1]。這種形式的标簽更加的常見。
  3. Network.out : 網絡最後一層的輸出,注意是沒有經過softmax的網絡的輸出,通常是softmax函數的輸入值。
  4. Network.probs : 網絡輸出的機率結果,通常為網絡最後一層輸出經過softmax函數之後的結果,Network.probs=tf.nn.softmax(Network.out)
  5. Network.pred : 網絡的預測結果,在onehot的形式中選擇機率最大的一類作為最終的預測結果,Network.pred=tf.argmax(Network.probs

    , axis=n)。

  6. Tensor : 一個張量,可以簡單的了解為是tensorflow中的一個數組。
  7. tf.reduce_sum(Tensor) : 降維加和,比如一個數組是333大小的,那麼經過這個操作之後會變為一個數字,即所有元素的加和。
  8. tf.reduce_mean(Tensor):降維平均,和上面的reduce_sum一樣,将高維的數組變為一個數,該數是數組中所有元素的均值。
Tensorflow 内置的四個損失函數 ↓

下面我們就進入正題啦。Tf内置的損失函數一共有四個,弄懂了一個其餘的就基本了解了,下面我們就逐一的介紹,其中第一個重點介紹,其餘的建立在大家對第一個了解的基礎之上。

  • ① Tensor=tf.nn.softmax_cross_entropy_with_logits(logits= Network.out, labels= Labels_onehot)

上面是softmax交叉熵loss,參數為網絡最後一層的輸出和onehot形式的标簽。切記輸入一定不要經過softmax,因為在函數中内置了softmax操作,如果再做就是重複使用了。在計算loss的時候,輸出Tensor要加上tf.reduce_mean(Tensor)或者tf.reduce_sum(Tensor),作為tensorflow優化器(optimizer)的輸入。

  • ② Tensor=tf.nn.sparse_softmax_cross_entropy_with_logits (logits=Network.out, labels= Labels)

這個函數和上面的差別就是labels參數應該是沒有經過onehot的标簽,其餘的都是一樣的。另外加了sparse的loss還有一個特性就是标簽可以出現-1,如果标簽是-1代表這個資料不再進行梯度回傳。

  • ③ Tensor=tf.nn. sigmoid_cross_entropy_with_logits (logits= Network.out, labels= Labels_onehot)

sigmoid交叉熵loss,與softmax不同的是,該函數首先進行sigmoid操作之後計算交叉熵的損失函數,其餘的特性與tf.nn.softmax_cross_entropy_with_logits一緻。

  • ④Tensor=tf.nn.weighted_cross_entropy_with_logits (logits=Network.out, labels=Labels_onehot, pos_weight=decimal_number)

這個loss與衆不同的地方就是加入了一個權重的系數,其餘的地方與tf.nn. sigmoid_cross_entropy_with_logits這個損失函數是一緻的,加入的pos_weight函數可以适當的增大或者縮小正樣本的loss,可以一定程度上解決正負樣本數量差距過大的問題。對比下面兩個公式我們可以更加清晰的看到,他和sigmoid的損失函數的差別,對于普通的sigmoid來說計算的形式如下:

targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))

加入了pos_weight之後的計算形式如下:

targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))

繼續閱讀