天天看點

關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題

關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題

Hook 是 PyTorch 中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,友善地擷取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature、gradient,進而診斷神經網絡中可能出現的問題,分析網絡有效性。

Hook函數機制:不改變主體,實作額外的功能,像一個挂件一樣;

Hook函數本身不是本文介紹的重點,網上介紹的文章頗多,本文主要是記錄一下筆者在使用hook函數時遇到的一些問題及解決過程。

register_forward_hook

首先看一下一個最簡單的使用register_forward_hook的例子:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = F.relu(self.conv1(x))     #1 
        out = F.max_pool2d(out, 2)      #2
        out = F.relu(self.conv2(out))   #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
 
features = []
def hook(module, input, output): 
    # module: model.conv2 
    # input :in forward function  [#2]
    # output:is  [#3 self.conv2(out)]
    print('*'*100)
    features.append(output.clone().detach())
    # output is saved  in a list 
 
 
net = LeNet() ## 模型執行個體化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 擷取整個Lenet模型 conv2的中間結果
y = net(x)  ## 擷取的是 關于 input x 的 conv2 結果 
 

print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 ,防止多次儲存hook内容占用空間
           

輸出

****************************************************************************************************
torch.Size([2, 16, 10, 10])
           

形狀是我們想要的結果,列印一串*是為了直覺地驗證hook函數被調用了。

其中conv2的名稱,我們可以列印模型的state_dict()來檢視自己要的是哪個module

for k in model.state_dict():
    print(k)
           

輸出:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
           

我們上面直接拿conv2做例子了。

出現的問題

在實際使用中,我想列印最近的transformer模型alt_gvt_large的位置編碼來看一下,但是遇到了問題。

我檢視了一下模型中的module,找到自己想要的

import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transforms

fmap_block = []
def forward_hook(module, data_input, data_output):
    print('*'*100)
    fmap_block.append(data_output.clone().detach())

model = timm.create_model(
        'alt_gvt_large',
        pretrained=False,
        num_classes=1000,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_block_rate=None,
    )
pipeline = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    ])

for k in model.state_dict():
    print(k)
           

輸出:

# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
           

那肯定就是pos_block喽。

開始hook:

image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)

handle = model.pos_block.register_forward_hook(forward_hook)
  
pred = model(image)
print(fmap_block[0].shape)
handle.remove()
           

出大問題,根本沒有輸出,連我們設定來驗證hook函數運作的*也沒有出現,hook函數肯定沒有被執行,這是怎麼回事呢?

解決過程

經過仔細比對以上兩次成功和失敗hook經曆:

conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
           

簡單分析不難有如此猜測:隻有下面直接能點( . )到weight和bias的module才能被直接hook。

但是直接将輸出結果粘貼過去會出現:

直接報文法錯誤,數字肯定是不能直接點的。

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
                            ^
SyntaxError: invalid syntax
           

于是筆者一層一層檢視進去:

for k in model.pos_block:
    print(k)
    for _k in k.proj.state_dict():
        print(_k)
        break
    break 
print(type(model.pos_block))
           

發現上面出現數字的地方的類型其實是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一個list,那是不是直接可以用[ ]進行索引。

于是我們可以改為:

輸出:

****************************************************************************************************
torch.Size([1, 256, 28, 28])
           

終于成功。

總結

還是對PyTorch中的Model,Module,childeren_module等了解的不到位啊,隻會最基本的使用方法,稍微進階一點的操作就會遇到阻力,以後有時間梳理一下。PyTorch是當今公認比較好用的開源架構了,但是想要随心所欲地實作自己的想法,還是需要花點時間把其中的各個元件及互相之間的關系都了解到位。

繼續閱讀