天天看點

梯度累加是什麼意思-詳解

每次看到“梯度是累加的,是以需要清除梯度”這句話都感覺雲裡霧裡,貌似懂了實際沒懂,原來竟是這麼簡單的意思!

1、直接上代碼:

import torch

x = torch.Tensor([1, 2, 3])
x.requires_grad_()
print(x)
y = x**2

 # 連續調用backward時,需要retain_graph=True確定計算圖暫時不被釋放
y.sum().backward(retain_graph=True) 
print(x.grad)
y.sum().backward()
print(x.grad)  # 如果梯度不歸零的話,梯度是累加的
           

運作結果是:

tensor([1., 2., 3.], requires_grad=True)
tensor([2., 4., 6.])
tensor([ 4.,  8., 12.])
           

第一次調用backward反向傳播,結果是(2 4 6),中間沒有梯度清零,第二次調用backward反向傳播,又有了一波結果(2 4 6),加在之前的結果上就得了(4 8 12)

2、接下來,我們在兩次調用之間加一個梯度清零操作看看:

import torch

x = torch.Tensor([1, 2, 3])
x.requires_grad_()
print(x)
y = x**2

y.sum().backward(retain_graph=True)  # 連續調用backward時,需要retain_graph=True確定計算圖暫時不被釋放
print(x.grad)
x.grad.zero_()
y.sum().backward()
print(x.grad)  # 如果梯度不歸零的話,梯度是累加的
           

運作結果是:

tensor([1., 2., 3.], requires_grad=True)
tensor([2., 4., 6.])
tensor([2., 4., 6.])
           

繼續閱讀