天天看點

PyTorch學習—16.PyTorch中hook函數

文章目錄

      • 一、Hook函數概念
      • 二、四種Hook函數介紹
        • 1. Tensor.register_hook
        • 2. Module.register_forward_hook
        • 3.Module.register_forward_pre_hook
        • 4. Module.register_backward_hook

一、Hook函數概念

  Hook函數機制:不改變主體,實作額外功能,像一個挂件一樣将功能挂到函數主體上。Hook函數與PyTorch中的動态圖運算機制有關,因為在動态圖計算,在運算結束後,中間變量是會被釋放掉的,例如:非葉子節點的梯度。但是,我們往往想要提取這些中間變量,這時,我們就可以采用Hook函數在前向傳播與反向傳播主體上挂上一些額外的功能(函數),通過這些函數擷取中間的梯度,甚至是改變中間的梯度。PyTorch一共提供了四種Hook函數:

  • torch.Tensor.register_hook(hook)
  • torch.nn.Module.register_forward_hook
  • torch.nn.Module.register_forward_pre_hook
  • torch.nn.Module.register_backward_hook

一種是針對Tensor,其餘三種是針對網絡的

二、四種Hook函數介紹

1. Tensor.register_hook

def register_hook(self, hook):
	"""
	接受一個hook函數
	"""
	...
           

功能:注冊一個反向傳播hook函數,這是因為張量在反向傳播的時候,如果不是葉子節點,它的梯度就會消失。由于反向傳播過程中存在資料的釋放,是以就有了反向傳播的hook函數

  • Hook函數僅一個輸入參數,為張量的梯度

  下面,我們通過計算圖流程來觀察張量梯度的擷取以及熟悉Hook函數。

y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)∗(w+1)

PyTorch學習—16.PyTorch中hook函數
import torch
import torch.nn as nn

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# 存儲張量的梯度
a_grad = list()

def grad_hook(grad):
    """
    定義一個hook函數,将梯度存儲到清單中
    :param grad: 梯度
    :return:
    """
    a_grad.append(grad)

# 注冊一個反向傳播的hook函數,功能是将梯度存儲到a_grad清單中
handle = a.register_hook(grad_hook)
# 反向傳播
y.backward()

# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
           
tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])
           

如果對葉子節點的張量使用hook函數,那麼會怎麼樣呢?

import torch
import torch.nn as nn

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 檢視梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("w.grad: ", w.grad)
handle.remove()
           
gradient: tensor([30.]) tensor([2.]) None None None
w.grad:  tensor([30.])
           

與上面比較,發現hook函數相當于對已有張量進行原地操作

2. Module.register_forward_hook

def register_forward_hook(self, hook):
	...
           

功能:注冊module的前向傳播hook函數

參數:

  • module: 目前網絡層
  • input:目前網絡層輸入資料
  • output:目前網絡層輸出資料

3.Module.register_forward_pre_hook

功能:注冊module前向傳播前的hook函數

參數:

  • module: 目前網絡層
  • input:目前網絡層輸入資料

4. Module.register_backward_hook

功能:注冊module反向傳播的hook函數

參數:

  • module: 目前網絡層
  • grad_input:目前網絡層輸入梯度資料
  • grad_output:目前網絡層輸出梯度資料

下面例子展示這三個hook函數

PyTorch學習—16.PyTorch中hook函數
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    """
    定義前向傳播hook函數
    :param module:網絡
    :param data_input:輸入資料
    :param data_output:輸出資料
    """
    fmap_block.append(data_output)
    input_block.append(data_input)

def forward_pre_hook(module, data_input):
    """
    定義前向傳播前的hook函數
    :param module: 網絡
    :param data_input: 輸入資料
    :return:
    """
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module, grad_input, grad_output):
    """
    定義反向傳播的hook函數
    :param module: 網絡
    :param grad_input: 輸入梯度
    :param grad_output: 輸出梯度
    :return:
    """
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

# 初始化網絡
net = Net()
# 第一個卷積核全設定為1
net.conv1.weight[0].detach().fill_(1)
# 第二個卷積核全設定為2
net.conv1.weight[1].detach().fill_(2)
# bias不考慮
net.conv1.bias.data.detach().zero_()

# 注冊hook
fmap_block = list()
input_block = list()
# 給卷積層注冊前向傳播hook函數
net.conv1.register_forward_hook(forward_hook)
# 給卷積層注冊前向傳播前的hook函數
net.conv1.register_forward_pre_hook(forward_pre_hook)
# 給卷積層注冊反向傳播的hook函數
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

# 觀察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
           
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],
        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],
         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],
         [[18., 18.],
          [18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
           

如果對您有幫助,麻煩點贊關注,這真的對我很重要!!!如果需要互關,請評論或者私信!

PyTorch學習—16.PyTorch中hook函數

繼續閱讀