當我們嘗試訓練神經網絡的時候,不可避免地要接觸到損失函數,損失函數計算真實值和預測值的誤差。tensorflow2.0已經給我們封裝好的具備很多用途的損失函數,我們可以隻用兩行代碼就可以直接使用,簡直友善地不要不要的。
我先說如何使用,再說有哪些可以供我們挑選使用
如何使用看下面代碼,分析過程在代碼的注釋裡面,注意看代碼注釋,注意看代碼注釋,注意看代碼注釋。
from tensorflow.keras import losses
# 假設y_true是真實值, y_pred是網絡預測值
import tensorflow as tf
y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
# 執行個體化一個損失對象
loss_object = losses.CategoricalCrossentropy() # 這個類裡面的參數很值的研究,一般都是預設即可
"""
類似于這種CategoricalCrossentropy類,tensorflow2.0給我提供了幾種呢?
答案是有好多好多,具體分析接着看部落格。你也可以按住ctrl健點選“losses”,進入源碼裡面看看
"""
# 通過該損失對象計算損失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred) # losses_hjx為tf.Tensor(15.427775, shape=(), dtype=float32)
pass
我覺得,就算我現在把所有常用的損失函數都告訴你了,我相信你也是一刷而過,絲毫沒有覺得有用的感覺,反而感覺壓力很大。是以我就不把這些内置的損失函數逐一告訴你們了。換言之,我們完全可以不用别人寫好的東西呀,我們想要什麼就自己來自定義什麼呗,難道不是很快樂嗎?
是以
接着我要告訴你們如何自定義損失函數,當然啦,tensorflow2.0确實已經給我們做好了太多東西了,你可以直接使用他們的内置函數。想知道還有哪些内置函數的童鞋,評論區call我,我發給你10G資料研究研究啧啧啧。
如何自定義損失函數,代碼分析在注釋裡面,注意看代碼注釋,注意看代碼注釋,注意看代碼注釋
from tensorflow.keras import losses
# 假設y_true是真實值, y_pred是網絡預測值
import tensorflow as tf
y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
class FocalLoss(losses.Loss): # 繼承Loss類
# 重寫初始化方法,其實就是定義一些自己損失邏輯可能使用到的參數,格式如下
def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
super(FocalLoss, self).__init__(**kwargs)
self.gamma = gamma
self.alpha = alpha
# call函數是重點,重寫了損失函數的運算邏輯,這也是一個損失函數的本質了,下面損失邏輯是我随便寫的
def call(self, y_true, y_pred):
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
loss = pt_0 + pt_1
return loss
# 使用方法跟内置損失函數的使用方法一樣,看下面
# 執行個體化一個損失對象
loss_object = FocalLoss(name='focalloss') # 這個類裡面的參數必須要傳遞一個參數name,name的值可以自定義
# 通過該損失對象計算損失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred) # losses_hjx為tf.Tensor(15.427775, shape=(), dtype=float32)
pass
有一些童鞋可能聰明一點,現在是不是在想,為啥我不直接以一個函數的形式來實作這個損失函數對吧。如果你想到這個問題,說明你太聰明了。
對的
為啥我們不使用自己寫的普通形式的函數來定義損失函數呢。
原因就是