天天看點

Pytorch中requires_grad_(), tensor.data,detach(), torch.no_grad()的差別

一、tensor.data的使用 (屬性)

import torch
 
a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()
c = out.data  # 通過.data “分離”得到的的變量會和原來的變量共用資料(指向同一位址),而且新分離得到的張量是不可求導的
c.zero_()     # 改變c的值,原來的out也會改變
print(c.requires_grad)   #false
print(c)      #tensor([0., 0., 0.])
print(out.requires_grad)     #True
print(out)    #tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
print("----------------------------------------------")
 
out.sum().backward() # 對原來的out求導,
print(a.grad)  # 不會報錯,但是結果卻并不正确

           

.data的兩點總結:

(1)tensor .data 将傳回相同資料的 tensor,而且兩個tensor共用資料,一者改變,另一者也會跟着改變。新分離得到的tensor的require s_grad = False, 即不可求導的。(這一點其實detach是一樣的)

(2)使用tensor.data的局限性。文檔中說使用tensor.data是不安全的, 因為 x.data 不能被 autograd 追蹤求微分 。什麼意思呢?從上面的例子可以看出,由于我更改分離之後的變量值c,導緻原來的張量out的值也跟着改變了,但是這種改變對于autograd是沒有察覺的,它依然按照求導規則來求導,導緻得出完全錯誤的導數值卻渾然不知。它的風險性就是如果我再任意一個地方更改了某一個張量,求導的時候也沒有通知我已經在某處更改了,導緻得出的導數值完全不正确,故而風險大。這種方式已經基本被淘汰了。

二、detach() (方法)

import torch
 
a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()
c = out.detach()  # 通過.detach() “分離”得到的的變量也會與原張量使用同一資料,而且新分離得到的張量是不可求導的
c.zero_()         # 改變c的值,原來的out也會改變
print(c.requires_grad)    #false
print(c)          #tensor([0., 0., 0.])
print(out.requires_grad)  #true
print(out)        #tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
print("----------------------------------------------")
 
out.sum().backward()      # 對原來的out求導,
print(a.grad)     # 此時會報錯,監測到梯度計算所需要的張量已經被“原位操作inplace”所更改了

           

detach()的兩點總結:

(1)tensor.detach() 将傳回相同資料的 tensor,而且兩個tensor共用資料,一者改變,另一者也會跟着改變。新分離得到的tensor的require s_grad = False, 即不可求導的。(這一點和.data是一樣的)

(2)使用tensor.detach()的優點。從上面的例子可以看出,detach存在監測機制。由于我更改分離之後的變量值c,導緻原來的張量out的值也跟着改變了,這個時候如果依然按照求導規則來求導肯定會錯誤,是以不會再繼續求導了。

三、requires_grad_()

需要注意的是,這個是函數,與requires_grad屬性不同。

這個函數的作用是改變requires_grad屬性并傳回tensor,修改requires_grad屬性是inplace操作,預設參數為requires_grad=True。

Pytorch中requires_grad_(), tensor.data,detach(), torch.no_grad()的差別

注意:隻開不關

四、torch.no_grad()

用來禁止梯度的計算,常用在網絡推斷。

被with torch.no_grad包起來的操作,仍會運作或計算,但是他們的requires_grad屬性會被賦為False。進而在計算圖中關閉這些操作的梯度計算。

Pytorch中requires_grad_(), tensor.data,detach(), torch.no_grad()的差別

當不需要進行反向傳播或梯度計算時,requires_grad=True的變量會占用很多計算資源及存儲資源。with torch.no_grad作用範圍中的操作不會建構計算圖。

參考:https://blog.csdn.net/qq_27825451/article/details/96837905