Tensor copying in PyTorch is handled by the clone() and detach() functions, which between them cover most use cases.
clone
clone() returns a tensor with the same content as the original. The new tensor gets its own freshly allocated memory, but it remains in the computation graph, so gradients flowing through it propagate back to the source tensor.
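A minimal sketch of both properties (new storage, yet the gradient still flows back through the clone); variable names here are just for illustration:

import torch

a = torch.tensor([1., 2.], requires_grad=True)
b = a.clone()                        # new storage, recorded in the graph
print(b.data_ptr() == a.data_ptr())  # False: memory is not shared
b.sum().backward()                   # gradient flows through the clone...
print(a.grad)                        # ...back to a: tensor([1., 1.])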
detach
detach() also returns a tensor with the same content, but the new tensor shares memory with the original, is detached from the computation graph, and does not take part in gradient computation. In addition, certain in-place operations (such as resize_ / resize_as_ / set_ / transpose_) on either of the two tensors will raise an error.
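A small sketch of the shared storage, plus the in-place restriction just mentioned (the resize_ call should raise a RuntimeError on current PyTorch versions):

import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = a.detach()                       # shared storage, requires_grad=False
print(b.data_ptr() == a.data_ptr())  # True: same underlying memory
b[0] = 10.                           # in-place edit is visible through a
print(a)                             # tensor([10., 2., 3.], requires_grad=True)
try:
    b.resize_(2)                     # metadata-changing in-place op
except RuntimeError as e:
    print(e)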
Usage analysis
Operation       | New/Shared memory | Still in computation graph
tensor.clone()  | New               | Yes
tensor.detach() | Shared            | No
The examples below walk through this behavior.
First, import the package and fix the random seed:
import torch
torch.manual_seed(0)
1. After clone() the tensor keeps requires_grad=True, and after detach() it has requires_grad=False; even so, no gradient flows into the clone()d tensor (here clone_x does not even participate in computing y, so its .grad stays None).
x = torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()                  # new memory, still in the graph
detach_x = x.detach()                # shared memory, out of the graph
clone_detach_x = x.clone().detach()  # independent copy, out of the graph
f = torch.nn.Linear(3, 1)
y = f(x)
y.backward()
print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
Output:
--------------------------------------------
tensor([-0.0043, 0.3097, -0.4752])
True
None
False
False
--------------------------------------------
2. Feed the clone()d tensor into the computation instead of the original. The gradient still flows only to the original tensor: clone_x is a non-leaf node, so it passes the gradient through to x and its own .grad remains None.
x = torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.detach().clone()
f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()
print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
Output:
------------------------------------
tensor([-0.0043, 0.3097, -0.4752])
None
False
False
------------------------------------
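As noted, clone_x.grad is None because clone_x is a non-leaf tensor. If you do want autograd to populate its gradient, retain_grad() does that; a minimal sketch, continuing from the import above:

x = torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
clone_x.retain_grad()  # ask autograd to keep the gradient on this non-leaf
f = torch.nn.Linear(3, 1)
f(clone_x).backward()
print(x.grad)          # gradient still reaches the original
print(clone_x.grad)    # now populated instead of None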
3. Set the original tensor to requires_grad=False and turn the clone into a leaf with .requires_grad_(). When the clone participates in the computation, the gradient now flows to the clone()d tensor instead.
x = torch.tensor([1., 2., 3.], requires_grad=False)
clone_x = x.clone().requires_grad_()
detach_x = x.detach()
clone_detach_x = x.detach().clone()
f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()
print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
Output:
--------------------------------------
None
tensor([-0.0043, 0.3097, -0.4752])
False
False
--------------------------------------
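This is also why x.detach().clone().requires_grad_(True) (equivalently x.clone().detach().requires_grad_(True)) is the usual idiom for making an independent, trainable copy of a tensor; a minimal sketch, continuing from the import above:

x = torch.tensor([1., 2., 3.], requires_grad=True)
x2 = x.detach().clone().requires_grad_(True)  # fresh leaf, own memory
f = torch.nn.Linear(3, 1)
f(x2).backward()
print(x.grad)   # None: the original graph is untouched
print(x2.grad)  # the gradient lands on the copy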
4. Because detach() shares memory, an in-place parameter update is visible through the detached tensor: after optimizer.step() modifies f.weight in place, w shows the new values too.
x = torch.tensor([1., 2., 3.], requires_grad=True)
f = torch.nn.Linear(3, 1)
w = f.weight.detach()  # shares storage with the parameter
print(f.weight)
print(w)
y = f(x)
y.backward()
optimizer = torch.optim.SGD(f.parameters(), 0.1)
optimizer.step()
print(f.weight)
print(w)
Output:
----------------------------------------------------------
Parameter containing:
tensor([[-0.0043, 0.3097, -0.4752]], requires_grad=True)
tensor([[-0.0043, 0.3097, -0.4752]])
Parameter containing:
tensor([[-0.1043, 0.1097, -0.7752]], requires_grad=True)
tensor([[-0.1043, 0.1097, -0.7752]])
----------------------------------------------------------
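If instead you want a frozen snapshot that an in-place update cannot touch, combine the two calls; a minimal sketch reusing the setup above:

x = torch.tensor([1., 2., 3.], requires_grad=True)
f = torch.nn.Linear(3, 1)
w = f.weight.detach().clone()  # own memory, out of the graph
f(x).backward()
torch.optim.SGD(f.parameters(), 0.1).step()
print(f.weight)  # updated in place by the optimizer
print(w)         # still holds the pre-update values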