1. What hooks are for
In PyTorch, the gradients of intermediate (non-leaf) tensors produced during training are released as soon as backpropagation finishes. To save these values you need a hook function: a hook can be thought of as a plug-in function attached to an existing tensor or module.
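To see why this matters, here is a minimal sketch (not part of the original example): after backward(), gradients are kept only for leaf tensors, while the gradient of an intermediate tensor is released.

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)   # a is a non-leaf, intermediate tensor
y = torch.mul(a, a)
y.backward()
print(w.grad)   # tensor([6.]) -- leaf gradient is retained
print(a.grad)   # None -- the intermediate gradient was released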
2. Hooks on tensors
A tensor hook is used to capture gradients during backpropagation. The example below registers one on the intermediate tensor a:
import torch

flag = 1
if flag:
    # build a small computational graph
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    # an empty list to store the gradients captured by the hook
    a_grad = list()
    # define the hook function
    def grad_hook(grad):
        # save the captured gradient into a_grad
        a_grad.append(grad)
        grad *= 2
        # if the hook returns a tensor, that tensor replaces the gradient that
        # keeps flowing backward; if it returns None, the gradient is unchanged
        return grad * 3
    # attach the hook to tensor a
    handle = a.register_hook(grad_hook)
    # run backpropagation; a's hook fires when a's gradient is computed
    y.backward()
    # inspect the gradient saved by the hook, and the leaf gradient
    print("a_grad[0]: ", a_grad[0])
    print("w.grad: ", w.grad)
    handle.remove()
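For this graph dy/da = b = 2. The hook stores that gradient, doubles it in place, and returns grad * 3, so the value that keeps flowing backward is 12; a_grad[0] therefore shows tensor([4.]) (it is the very tensor that was modified in place), and w.grad comes out as 12 + dy/db = 12 + 3 = tensor([15.]) instead of the tensor([5.]) you would get without the hook. Note that modifying grad in place inside a hook, as done here, is discouraged; returning a new tensor is the safer pattern. Finally, handle.remove() detaches the hook so later backward passes are unaffected.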
3. Hooks on network modules
There are three kinds:
forward_pre_hook(module, input): runs before the module's forward pass, so it can inspect (or replace) the input
forward_hook(module, input, output): runs after the forward pass; typically used to record input and output feature maps
backward_hook(module, grad_input, grad_output): fires during backpropagation and receives the gradients with respect to the module's inputs and outputs
import torch
import torch.nn as nn

flag = 1
if flag:
    # define the network
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    # hook functions
    def forward_hook(module, data_input, data_output):
        # record the module's input and output feature maps
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # initialize the network with fixed weights so the output is predictable
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.detach().zero_()

    # register the hooks on conv1
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(output, target)
    loss.backward()
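After loss.backward() has run, the two lists hold the data captured during the forward pass. A quick sanity check (a sketch; the shapes follow from applying Conv2d(1, 2, 3) to the 1x1x4x4 input):

print("feature map shape:", fmap_block[0].shape)   # torch.Size([1, 2, 2, 2])
print("input shape:", input_block[0][0].shape)     # torch.Size([1, 1, 4, 4])

Note that data_input arrives as a tuple, hence the extra [0]. Each register_* call also returns a handle whose remove() detaches the hook, exactly as in the tensor example above.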
When output = net(fake_img) is executed, Python invokes the module's __call__, and what actually runs is nn.Module._call_impl:
def _call_impl(self, *input, **kwargs):
    # -------- check for forward_pre_hooks and run them --------
    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    # -------- the actual forward pass --------
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    # -------- check for forward_hooks and run them --------
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    # -------- check for backward_hooks and register them on the output's grad_fn --------
    if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in itertools.chain(
                    _global_backward_hooks.values(),
                    self._backward_hooks.values()):
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
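Note that this last branch does not execute the backward hooks here: it only wraps each hook with functools.partial and registers the wrapper on the grad_fn of the (first) output tensor, so the hooks actually fire later, during backward(). Because the hook hangs off a single output node, register_backward_hook can report incomplete grad_input for modules composed of several operations, which is why PyTorch 1.8+ deprecates it in favor of register_full_backward_hook. A minimal sketch of the replacement, reusing the backward_hook defined earlier (a fresh module instance is used because PyTorch refuses to mix the old and new hook styles on one module):

net2 = Net()
# same (module, grad_input, grad_output) signature, but grad_input is
# computed correctly even for composite modules
handle = net2.conv1.register_full_backward_hook(backward_hook)
output = net2(fake_img)
output.sum().backward()
handle.remove()  # detach the hook once it is no longer needed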