天天看点

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。

Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;

Hook函数本身不是本文介绍的重点,网上介绍的文章颇多,本文主要是记录一下笔者在使用hook函数时遇到的一些问题及解决过程。

register_forward_hook

首先看一下一个最简单的使用register_forward_hook的例子:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = F.relu(self.conv1(x))     #1 
        out = F.max_pool2d(out, 2)      #2
        out = F.relu(self.conv2(out))   #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
 
features = []
def hook(module, input, output): 
    # module: model.conv2 
    # input :in forward function  [#2]
    # output:is  [#3 self.conv2(out)]
    print('*'*100)
    features.append(output.clone().detach())
    # output is saved  in a list 
 
 
net = LeNet() ## 模型实例化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 获取整个Lenet模型 conv2的中间结果
y = net(x)  ## 获取的是 关于 input x 的 conv2 结果 
 

print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 ,防止多次保存hook内容占用空间
           

输出

****************************************************************************************************
torch.Size([2, 16, 10, 10])
           

形状是我们想要的结果,打印一串*是为了直观地验证hook函数被调用了。

其中conv2的名称,我们可以打印模型的state_dict()来查看自己要的是哪个module

for k in model.state_dict():
    print(k)
           

输出:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
           

我们上面直接拿conv2做例子了。

出现的问题

在实际使用中,我想打印最近的transformer模型alt_gvt_large的位置编码来看一下,但是遇到了问题。

我查看了一下模型中的module,找到自己想要的

import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transforms

fmap_block = []
def forward_hook(module, data_input, data_output):
    print('*'*100)
    fmap_block.append(data_output.clone().detach())

model = timm.create_model(
        'alt_gvt_large',
        pretrained=False,
        num_classes=1000,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_block_rate=None,
    )
pipeline = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    ])

for k in model.state_dict():
    print(k)
           

输出:

# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
           

那肯定就是pos_block喽。

开始hook:

image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)

handle = model.pos_block.register_forward_hook(forward_hook)
  
pred = model(image)
print(fmap_block[0].shape)
handle.remove()
           

出大问题,根本没有输出,连我们设置来验证hook函数运行的*也没有出现,hook函数肯定没有被执行,这是怎么回事呢?

解决过程

经过仔细比对以上两次成功和失败hook经历:

conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
           

简单分析不难有如此猜测:只有下面直接能点( . )到weight和bias的module才能被直接hook。

但是直接将输出结果粘贴过去会出现:

直接报语法错误,数字肯定是不能直接点的。

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
                            ^
SyntaxError: invalid syntax
           

于是笔者一层一层查看进去:

for k in model.pos_block:
    print(k)
    for _k in k.proj.state_dict():
        print(_k)
        break
    break 
print(type(model.pos_block))
           

发现上面出现数字的地方的类型其实是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一个list,那是不是直接可以用[ ]进行索引。

于是我们可以改为:

输出:

****************************************************************************************************
torch.Size([1, 256, 28, 28])
           

终于成功。

总结

还是对PyTorch中的Model,Module,childeren_module等理解的不到位啊,只会最基本的使用方法,稍微进阶一点的操作就会遇到阻力,以后有时间梳理一下。PyTorch是当今公认比较好用的开源框架了,但是想要随心所欲地实现自己的想法,还是需要花点时间把其中的各个组件及相互之间的关系都理解到位。

继续阅读