文章目錄
-
-
- 一、Hook函數概念
- 二、四種Hook函數介紹
-
- 1. Tensor.register_hook
- 2. Module.register_forward_hook
- 3.Module.register_forward_pre_hook
- 4. Module.register_backward_hook
-
一、Hook函數概念
Hook函數機制:不改變主體,實作額外功能,像一個挂件一樣将功能挂到函數主體上。Hook函數與PyTorch中的動态圖運算機制有關,因為在動态圖計算,在運算結束後,中間變量是會被釋放掉的,例如:非葉子節點的梯度。但是,我們往往想要提取這些中間變量,這時,我們就可以采用Hook函數在前向傳播與反向傳播主體上挂上一些額外的功能(函數),通過這些函數擷取中間的梯度,甚至是改變中間的梯度。PyTorch一共提供了四種Hook函數:
- torch.Tensor.register_hook(hook)
- torch.nn.Module.register_forward_hook
- torch.nn.Module.register_forward_pre_hook
- torch.nn.Module.register_backward_hook
一種是針對Tensor,其餘三種是針對網絡的
二、四種Hook函數介紹
1. Tensor.register_hook
def register_hook(self, hook):
"""
接受一個hook函數
"""
...
功能:注冊一個反向傳播hook函數,這是因為張量在反向傳播的時候,如果不是葉子節點,它的梯度就會消失。由于反向傳播過程中存在資料的釋放,是以就有了反向傳播的hook函數
- Hook函數僅一個輸入參數,為張量的梯度
下面,我們通過計算圖流程來觀察張量梯度的擷取以及熟悉Hook函數。
y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)∗(w+1)
import torch
import torch.nn as nn
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 存儲張量的梯度
a_grad = list()
def grad_hook(grad):
"""
定義一個hook函數,将梯度存儲到清單中
:param grad: 梯度
:return:
"""
a_grad.append(grad)
# 注冊一個反向傳播的hook函數,功能是将梯度存儲到a_grad清單中
handle = a.register_hook(grad_hook)
# 反向傳播
y.backward()
# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
如果對葉子節點的張量使用hook函數,那麼會怎麼樣呢?
import torch
import torch.nn as nn
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
grad *= 2
return grad*3
handle = w.register_hook(grad_hook)
y.backward()
# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("w.grad: ", w.grad)
handle.remove()
gradient: tensor([30.]) tensor([2.]) None None None
w.grad: tensor([30.])
與上面比較,發現hook函數相當于對已有張量進行原地操作
2. Module.register_forward_hook
def register_forward_hook(self, hook):
...
功能:注冊module的前向傳播hook函數
參數:
- module: 目前網絡層
- input:目前網絡層輸入資料
- output:目前網絡層輸出資料
3.Module.register_forward_pre_hook
功能:注冊module前向傳播前的hook函數
參數:
- module: 目前網絡層
- input:目前網絡層輸入資料
4. Module.register_backward_hook
功能:注冊module反向傳播的hook函數
參數:
- module: 目前網絡層
- grad_input:目前網絡層輸入梯度資料
- grad_output:目前網絡層輸出梯度資料
下面例子展示這三個hook函數
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
def forward_hook(module, data_input, data_output):
"""
定義前向傳播hook函數
:param module:網絡
:param data_input:輸入資料
:param data_output:輸出資料
"""
fmap_block.append(data_output)
input_block.append(data_input)
def forward_pre_hook(module, data_input):
"""
定義前向傳播前的hook函數
:param module: 網絡
:param data_input: 輸入資料
:return:
"""
print("forward_pre_hook input:{}".format(data_input))
def backward_hook(module, grad_input, grad_output):
"""
定義反向傳播的hook函數
:param module: 網絡
:param grad_input: 輸入梯度
:param grad_output: 輸出梯度
:return:
"""
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
# 初始化網絡
net = Net()
# 第一個卷積核全設定為1
net.conv1.weight[0].detach().fill_(1)
# 第二個卷積核全設定為2
net.conv1.weight[1].detach().fill_(2)
# bias不考慮
net.conv1.bias.data.detach().zero_()
# 注冊hook
fmap_block = list()
input_block = list()
# 給卷積層注冊前向傳播hook函數
net.conv1.register_forward_hook(forward_hook)
# 給卷積層注冊前向傳播前的hook函數
net.conv1.register_forward_pre_hook(forward_pre_hook)
# 給卷積層注冊反向傳播的hook函數
net.conv1.register_backward_hook(backward_hook)
# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
# 觀察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]],
[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
[0.0000, 0.0000]],
[[0.5000, 0.0000],
[0.0000, 0.0000]]]]),)
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9., 9.],
[ 9., 9.]],
[[18., 18.],
[18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
如果對您有幫助,麻煩點贊關注,這真的對我很重要!!!如果需要互關,請評論或者私信!