天天看點

sigmoid 函數的損失函數與參數更新1 sigmoid 函數的損失函數與參數更新

1 sigmoid 函數的損失函數與參數更新

邏輯回歸對應線性回歸,但旨在解決分類問題,即将模型的輸出轉換為 $[0, 1]$ 的機率值。邏輯回歸直接對分類的可能性進行模組化,無需事先假設資料的分布。最理想的轉換函數為機關階躍函數(也稱 Heaviside 函數),但機關階躍函數是不連續的,沒法在實際計算中使用。故而,在分類過程中更常使用對數幾率函數(即 sigmoid 函數):

$$

\sigma(x) = \frac{1}{1+e^{-x}}

易推知,$\sigma(x)' = \sigma(x)(1- \sigma(x))$.

假設我們有 $m$ 個樣本 $D = \{(x_i, y_i)\}_i^m$, 令 $X = (x_1, x_2, \cdots, x_m)^T, y = (y_1, y_2, \cdots, y_m)^T$, 其中 $x_i \in \mathbb{R}^n, y_i \in \{0, 1\}$, 關于參數 $w \in \mathbb{R}^n, b \in \mathbb{R}$, ($b$ 需要廣播操作),我們定義正例的機率為

P(y_j=1|x_j;w,b) = \sigma(x_j^Tw +b) = \sigma(z_j)

這樣屬于類别 $y$ 的機率可改寫為

P(y_j|x_j;w,b) = \sigma(z_j)^{y_j}(1-\sigma(z_j))^{1-y_j}

令 $z = (z_1, \cdots, z_m)^T$, 則記 $h(z) = (\sigma(z_1), \cdots, \sigma(z_m))^T$, 且 Logistic Regression 的損失函數為

\begin{aligned}

L(w, b) =& - \displaystyle \frac{1}{m} \sum_{i=1}^m (y_i \log (\sigma(z_i)) +(1-y_i) \log (1 - \sigma(z_i)))\\

=& - \frac{1}{m} (y^T\log (h(z)) + (\mathbf{1}-y)^T\log(\mathbf{1}- h(z))), \text{ 此時做了廣播操作}

\end{aligned}

這樣,我們有

\begin{cases}

\nabla_w L(w,b) = \frac{\text{d}z}{\text{d}w} \frac{\text{d}L}{\text{d}z} = - \frac{1}{m}X^T(y-h(z))\\

\nabla_b L(w,b) = \frac{\text{d}z}{\text{d}b} \frac{\text{d}L}{\text{d}z} = - \frac{1}{m}\mathbf{1}^T(y-h(z))

\end{cases}

其中,$\mathbf{1}$ 表示全一列向量。這樣便有參數更新公式 ($\eta$ 為學習率):

w \leftarrow w - \eta \nabla_{w} L(w,b)\\

b \leftarrow b - \eta \nabla_b L(w,b)

更多機器學習中的數見:

機器學習中的數學