天天看點

PyTorch示例——利用PyTorch計算梯度

PyTorch示例——利用PyTorch計算梯度

    • 版本資訊
    • 導包
    • 原始資料
    • 定義函數
    • 計算梯度
    • 繪制曲線:epoch與loss

版本資訊

  • PyTorch:

    1.12.1

  • Python:

    3.7.13

導包

import torch
import matplotlib.pyplot as plt
           

原始資料

# 原始資料
x_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y_data = [0.4, 1.4, 2.8, 3.5, 4.8, 5.5]
plt.scatter(x_data, y_data)
plt.show()
           
PyTorch示例——利用PyTorch計算梯度

定義函數

def forward(x):
    # 前向傳播
    return x * w

def loss(x, y):
    # 計算損失
    y_pred = forward(x)
    return (y_pred - y) ** 2
           

計算梯度

  • 代碼
# 權重w,需要計算梯度
w = torch.Tensor([10.0])
w.requires_grad = True
# 學習率lr,不需要計算梯度
lr = torch.Tensor([0.005])
lr.requires_grad = False

loss_list = []
epochs = range(10)
for epoch in epochs:
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward() # 反向傳播計算每個點的梯度
        w.data = w.data - lr * w.grad.data # 學習
        w.grad.data.zero_() # 重置本次累積的梯度
    print(f"epoch = {epoch}, loss_val = {l.item()}")
    loss_list.append(l.item())
           
  • 輸出結果
epoch = 0, loss_val = 881.4310302734375
epoch = 1, loss_val = 107.05437469482422
epoch = 2, loss_val = 12.973189353942871
epoch = 3, loss_val = 1.5620064735412598
epoch = 4, loss_val = 0.18457315862178802
epoch = 5, loss_val = 0.020624106749892235
epoch = 6, loss_val = 0.0019251183839514852
epoch = 7, loss_val = 8.275721484096721e-05
epoch = 8, loss_val = 9.185609087580815e-06
epoch = 9, loss_val = 5.270535984891467e-05
           

繪制曲線:epoch與loss

plt.plot(epochs, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()
           
PyTorch示例——利用PyTorch計算梯度

繼續閱讀