關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題
Hook 是 PyTorch 中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,友善地擷取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature、gradient,進而診斷神經網絡中可能出現的問題,分析網絡有效性。
Hook函數機制:不改變主體,實作額外的功能,像一個挂件一樣;
Hook函數本身不是本文介紹的重點,網上介紹的文章頗多,本文主要是記錄一下筆者在使用hook函數時遇到的一些問題及解決過程。
register_forward_hook
首先看一下一個最簡單的使用register_forward_hook的例子:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.conv1(x)) #1
out = F.max_pool2d(out, 2) #2
out = F.relu(self.conv2(out)) #3
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
features = []
def hook(module, input, output):
# module: model.conv2
# input :in forward function [#2]
# output:is [#3 self.conv2(out)]
print('*'*100)
features.append(output.clone().detach())
# output is saved in a list
net = LeNet() ## 模型執行個體化
x = torch.randn(2, 3, 32, 32) ## input
handle = net.conv2.register_forward_hook(hook) ## 擷取整個Lenet模型 conv2的中間結果
y = net(x) ## 擷取的是 關于 input x 的 conv2 結果
print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 ,防止多次儲存hook内容占用空間
輸出
****************************************************************************************************
torch.Size([2, 16, 10, 10])
形狀是我們想要的結果,列印一串*是為了直覺地驗證hook函數被調用了。
其中conv2的名稱,我們可以列印模型的state_dict()來檢視自己要的是哪個module
for k in model.state_dict():
print(k)
輸出:
conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
我們上面直接拿conv2做例子了。
出現的問題
在實際使用中,我想列印最近的transformer模型alt_gvt_large的位置編碼來看一下,但是遇到了問題。
我檢視了一下模型中的module,找到自己想要的
import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transforms
fmap_block = []
def forward_hook(module, data_input, data_output):
print('*'*100)
fmap_block.append(data_output.clone().detach())
model = timm.create_model(
'alt_gvt_large',
pretrained=False,
num_classes=1000,
drop_rate=0.1,
drop_path_rate=0.1,
drop_block_rate=None,
)
pipeline = transforms.Compose([
transforms.RandomCrop(224),
transforms.ToTensor(),
])
for k in model.state_dict():
print(k)
輸出:
# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
那肯定就是pos_block喽。
開始hook:
image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)
handle = model.pos_block.register_forward_hook(forward_hook)
pred = model(image)
print(fmap_block[0].shape)
handle.remove()
出大問題,根本沒有輸出,連我們設定來驗證hook函數運作的*也沒有出現,hook函數肯定沒有被執行,這是怎麼回事呢?
解決過程
經過仔細比對以上兩次成功和失敗hook經曆:
conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
簡單分析不難有如此猜測:隻有下面直接能點( . )到weight和bias的module才能被直接hook。
但是直接将輸出結果粘貼過去會出現:
直接報文法錯誤,數字肯定是不能直接點的。
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
^
SyntaxError: invalid syntax
于是筆者一層一層檢視進去:
for k in model.pos_block:
print(k)
for _k in k.proj.state_dict():
print(_k)
break
break
print(type(model.pos_block))
發現上面出現數字的地方的類型其實是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一個list,那是不是直接可以用[ ]進行索引。
于是我們可以改為:
輸出:
****************************************************************************************************
torch.Size([1, 256, 28, 28])
終于成功。
總結
還是對PyTorch中的Model,Module,childeren_module等了解的不到位啊,隻會最基本的使用方法,稍微進階一點的操作就會遇到阻力,以後有時間梳理一下。PyTorch是當今公認比較好用的開源架構了,但是想要随心所欲地實作自己的想法,還是需要花點時間把其中的各個元件及互相之間的關系都了解到位。