天天看點

pytorch 多次backward

pytorch 多次backward

影響pytorch梯度的參數:

1.requires_grad 需要為True

for param in model.named_parameters():
    if param[0] in need_frozen_list:
        param[1].requires_grad = False
           

2.model.train(True)

如果網絡需要多次求loss,多次求導反向傳播:

如果我們再調用一次 backward,會發現程式報錯,沒有辦法再做一次。這是因為 PyTorch 預設做完一次自動求導之後,計算圖就被丢棄了,是以兩次自動求導需要手動設定一個東西,我們通過下面的小例子來說明。

import torch
from torch.autograd import Variable
x = Variable(torch.FloatTensor([3]), requires_grad=True)
y = x * 2 + x ** 2 + 3
print(y)
y.backward(retain_graph=True) # 設定 retain_graph 為 True 來保留計算圖
print(x.grad)
y.backward() # 再做一次自動求導,這次不保留計算圖
print(x.grad)
結果:

tensor([18.], grad_fn=<AddBackward0>)
tensor([8.])
tensor([16.])
           

可以發現 x 的梯度變成了 16,因為這裡做了兩次自動求導,是以講第一次的梯度 8 和第二次的梯度 8 加起來得到了 16 的結果。

繼續閱讀