pytorch 實作簡單線性回歸
問題描述:
使用 pytorch 實作一個簡單的線性回歸。
受教育年薪與收入資料集
單變量線性回歸
單變量線性回歸算法(比如,$x$ 代表學曆,$f(x)$ 代表收入):
$f(x) = w*x + b $
我們使用 $f(x)$ 這個函數來映射輸入特征和輸出值。
目标:
預測函數 $f(x)$ 與真實值之間的整體誤差最小。
損失函數:
使用均方差作為作為成本函數。
也就是預測值和真實值之間差的平方取均值。
成本函數與損失函數:
優化的目标( $y$ 代表實際的收入):
找到合适的 $w$ 和 $b$ ,使得 $(f(x) - y)^{2}$越小越好
注意:現在求解的是參數 $w$ 和 $b$。
過程
1 導入實驗所需要的包
2 讀取資料
3 檢視資料資訊
檢視資料
檢視資料類型
4 圖表顯示資料
5 轉換資料為 tensor 類型
檢視特征資料
檢視特征資料 index
檢視特征資料 value
特征資料變換形狀
檢視特征資料變換後的形狀
檢視特征資料變換後的資料類型
修改特征資料變換後的資料類型
特征資料和标簽轉換為tensor
6 定義模型
定義線性回歸模型:
定義均方損失函數
定義優化器
7 模型訓練
8 輸出權重和偏置
tensor 類型資料帶梯度轉換為numpy需要先去梯度
9 擷取預測值 y_pred
預測值類型
預測值size
10 繪制回歸曲線
完整代碼:
view code