天天看點

Pytorch 實作簡單線性回歸

  pytorch 實作簡單線性回歸

問題描述:

  使用 pytorch 實作一個簡單的線性回歸。

      

Pytorch 實作簡單線性回歸

            受教育年薪與收入資料集

單變量線性回歸

  單變量線性回歸算法(比如,$x$ 代表學曆,$f(x)$ 代表收入): 

    $f(x) = w*x + b $

  我們使用 $f(x)$ 這個函數來映射輸入特征和輸出值。

目标:

  預測函數 $f(x)$ 與真實值之間的整體誤差最小。

損失函數: 

  使用均方差作為作為成本函數。

  也就是預測值和真實值之間差的平方取均值。

成本函數與損失函數: 

  優化的目标( $y$ 代表實際的收入):

  找到合适的 $w$ 和 $b$ ,使得 $(f(x) - y)^{2}$越小越好

  注意:現在求解的是參數 $w$ 和 $b$。

過程

1 導入實驗所需要的包

2 讀取資料

3 檢視資料資訊

  檢視資料

Pytorch 實作簡單線性回歸

   檢視資料類型

4 圖表顯示資料

Pytorch 實作簡單線性回歸

5 轉換資料為 tensor 類型

檢視特征資料

檢視特征資料 index

檢視特征資料 value

特征資料變換形狀

檢視特征資料變換後的形狀

檢視特征資料變換後的資料類型

修改特征資料變換後的資料類型

特征資料和标簽轉換為tensor

6 定義模型

定義線性回歸模型:

定義均方損失函數

定義優化器

7 模型訓練

8 輸出權重和偏置

  tensor 類型資料帶梯度轉換為numpy需要先去梯度

9 擷取預測值 y_pred

預測值類型

預測值size

10 繪制回歸曲線

 完整代碼:

Pytorch 實作簡單線性回歸
Pytorch 實作簡單線性回歸

view code