The hook mechanism in PyTorch

1. What hooks are for

In PyTorch, the gradients of intermediate (non-leaf) variables produced during training are freed as soon as back-propagation finishes. To keep these values, you need a hook function. A hook can be thought of as a plug-in: a callback attached to an existing function that runs alongside it.
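A minimal sketch of the problem (the variable names here are illustrative): after backward(), only leaf tensors such as w keep their .grad, while the gradient of an intermediate tensor such as a has already been freed.

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)          # intermediate (non-leaf) tensor
y = torch.mul(a, a)
y.backward()
print("w.grad:", w.grad)     # leaf gradient is kept: tensor([6.])
print("a.grad:", a.grad)     # None: the intermediate gradient was freed
                             # (newer PyTorch also emits a UserWarning here)

PyTorch does offer a.retain_grad() for this particular case, but hooks are more general: they can also inspect or modify values in flight.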

2. Hooks on tensors

A tensor hook captures (and can modify) the tensor's gradient during back-propagation:

import torch

flag = 1
if flag:
    # Build a small graph: y = (w + x) * (w + 1)
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    # An empty list to hold the gradients captured by the hook
    a_grad = list()
    # Define the hook function
    def grad_hook(grad):
        # Save the captured gradient into a_grad
        a_grad.append(grad)
        grad *= 2  # in-place: this also mutates the tensor just appended to a_grad
        # If the hook returns a tensor, it replaces the gradient of the hooked
        # variable; if it returns None, the gradient is left unchanged
        return grad * 3
    # Attach the hook to tensor a
    handle = a.register_hook(grad_hook)
    # Run back-propagation; a's hook fires when a's gradient is computed
    y.backward()

    # w.grad reflects the gradient that the hook modified on its way through a
    print("w.grad: ", w.grad)
    handle.remove()
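Tracing the numbers through this example: the gradient arriving at a is dy/da = b = 2. The hook doubles it in place (so a_grad[0] ends up holding tensor([4.])) and returns 4 * 3 = 12, which replaces a's gradient. Continuing backward, w receives 12 through the a branch plus dy/db = a = 3 through the b branch, so the script should print w.grad: tensor([15.]).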
           

3. Hooks on modules (networks)

There are three kinds:

forward_pre_hook(module, input): runs before the module's forward, so it can record (or even replace) the input feature maps

forward_hook(module, input, output): runs after forward, recording the input and output feature maps

backward_hook(module, grad_input, grad_output): runs during back-propagation, recording the gradients flowing into and out of the module

import torch
import torch.nn as nn

flag = 1
if flag:
    # Define the network
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        # Record the input and output feature maps of the module
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # Initialize the network with fixed weights for a reproducible output
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.detach().zero_()

    # Register the hooks on conv1
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # Inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(output, target)
    loss.backward()
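One caveat: register_backward_hook has known issues (for modules composed of several operations, grad_input may not correspond to the module's actual inputs) and is deprecated in newer PyTorch releases in favor of register_full_backward_hook, added in PyTorch 1.8. A minimal sketch reusing the net defined above:

def full_backward_hook(module, grad_input, grad_output):
    # Same signature as backward_hook, but the reported gradients are
    # guaranteed to match the module's actual inputs and outputs
    print("full backward hook input:{}".format(grad_input))
    print("full backward hook output:{}".format(grad_output))

net.conv1.register_full_backward_hook(full_backward_hook)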
           

When output = net(fake_img) is executed, Python invokes Module.__call__, which actually runs the following _call_impl (excerpted from the PyTorch source):

#--------------------- Check for forward_pre_hooks and run them -----------------
    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
#--------------------- Run the actual forward -----------------
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
#--------------------- Check for forward_hooks and run them -----------------
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
#--------------------- Check for backward_hooks and register them on the output's grad_fn -----------------
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
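Note the input = result line in the first section above: if a forward_pre_hook returns a value, that value replaces the module's input. A small sketch of this behavior, reusing net and fake_img from the earlier example (the hook name and the scale factor are made up for illustration):

def scale_input_pre_hook(module, data_input):
    # data_input is a tuple of the module's positional inputs;
    # returning a tuple replaces what forward will receive
    return (data_input[0] * 2,)

h = net.conv1.register_forward_pre_hook(scale_input_pre_hook)
out_scaled = net(fake_img)    # conv1 now sees fake_img * 2
h.remove()                    # detach the hook when done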
           
