天天看點

tensorflow2.0損失函數總結和損失函數自定義

當我們嘗試訓練神經網絡的時候,不可避免地要接觸到損失函數,損失函數計算真實值和預測值的誤差。tensorflow2.0已經給我們封裝好的具備很多用途的損失函數,我們可以隻用兩行代碼就可以直接使用,簡直友善地不要不要的。

我先說如何使用,再說有哪些可以供我們挑選使用

如何使用看下面代碼,分析過程在代碼的注釋裡面,注意看代碼注釋,注意看代碼注釋,注意看代碼注釋。

from tensorflow.keras import losses

# 假設y_true是真實值, y_pred是網絡預測值
import tensorflow as tf

y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))

# 執行個體化一個損失對象
loss_object = losses.CategoricalCrossentropy()  # 這個類裡面的參數很值的研究,一般都是預設即可
"""
類似于這種CategoricalCrossentropy類,tensorflow2.0給我提供了幾種呢?
答案是有好多好多,具體分析接着看部落格。你也可以按住ctrl健點選“losses”,進入源碼裡面看看
"""

# 通過該損失對象計算損失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred)  # losses_hjx為tf.Tensor(15.427775, shape=(), dtype=float32)


pass      

我覺得,就算我現在把所有常用的損失函數都告訴你了,我相信你也是一刷而過,絲毫沒有覺得有用的感覺,反而感覺壓力很大。是以我就不把這些内置的損失函數逐一告訴你們了。換言之,我們完全可以不用别人寫好的東西呀,我們想要什麼就自己來自定義什麼呗,難道不是很快樂嗎?

是以

接着我要告訴你們如何自定義損失函數,當然啦,tensorflow2.0确實已經給我們做好了太多東西了,你可以直接使用他們的内置函數。想知道還有哪些内置函數的童鞋,評論區call我,我發給你10G資料研究研究啧啧啧。

如何自定義損失函數,代碼分析在注釋裡面,注意看代碼注釋,注意看代碼注釋,注意看代碼注釋

from tensorflow.keras import losses

# 假設y_true是真實值, y_pred是網絡預測值
import tensorflow as tf

y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))


class FocalLoss(losses.Loss):  # 繼承Loss類

    # 重寫初始化方法,其實就是定義一些自己損失邏輯可能使用到的參數,格式如下
    def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
        super(FocalLoss, self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha

    # call函數是重點,重寫了損失函數的運算邏輯,這也是一個損失函數的本質了,下面損失邏輯是我随便寫的
    def call(self, y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        loss = pt_0 + pt_1
        return loss


# 使用方法跟内置損失函數的使用方法一樣,看下面
# 執行個體化一個損失對象
loss_object = FocalLoss(name='focalloss')  # 這個類裡面的參數必須要傳遞一個參數name,name的值可以自定義

# 通過該損失對象計算損失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred)  # losses_hjx為tf.Tensor(15.427775, shape=(), dtype=float32)

pass      

有一些童鞋可能聰明一點,現在是不是在想,為啥我不直接以一個函數的形式來實作這個損失函數對吧。如果你想到這個問題,說明你太聰明了。

對的

為啥我們不使用自己寫的普通形式的函數來定義損失函數呢。

原因就是

繼續閱讀