天天看點

pytorch hook 鈎子

簡介

hook是鈎子,主要作用是不修改主代碼,能通過挂載鈎子實作額外功能。 pytorch中,主體就是forward和backward,而額外的功能就是對模型的變量進行操作,如“提取”特征圖,“提取”非葉子張量的梯度,修改張量梯度等等。hook功能即不必改變網絡輸入輸出的結構,就能友善地擷取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature、gradient。

tensor的hook;對module的前向、反向hook,一般來說共有三種hook。

下面的計算圖中,x y w 為葉子節點,而 z 為中間變量

pytorch hook 鈎子

pytorch的計算圖中隻有輸出對葉子結點變量的梯度被儲存下來, 所有中間變量的梯度隻被用于反向傳播,一旦完成反向傳播,中間變量的梯度就将自動釋放(雖然 requires_grad 的參數都是 True),進而節約記憶體。 擷取中間節點梯度還可以用 retain_grad(),但這樣也會會增加記憶體占用。

torch.Tensor.register_hook(hook_fn)

hook_fn(grad) -> Tensor or None ,其中grad就是這個tensor的梯度。

hook_fn是我們自定義的函數,假設對上圖中間節點z的hook_fn函數來說,輸入為變量 z 的梯度,輸出為一個 Tensor 或者是 None (None 一般用于直接列印梯度)。反向傳播時,梯度傳播到變量 z,再繼續向前傳播之前,将會傳入 hook_fn。如果hook_fn的傳回值是 None,那麼梯度将不改變,繼續向前傳播,如果 hook_fn的傳回值是 Tensor 類型,則該 Tensor 将取代 z 原有的梯度,向前傳播。改變中間變量的梯度,之前變量的梯度也會收到影響(變量x,y)。

功能:注冊一個反向傳播hook_fn函數,這個函數是Tensor類裡的,當計算tensor的梯度時自動執行。

為什麼是backward?因為這個hook是針對tensor的,tensor中的什麼東西會在計算結束後釋放呢?隻有gradient嘛,是以是 backward hook. 

應用場景舉例:在hook_fn函數中可對梯度grad進行in-place操作,即可修改tensor的grad值。

下面是一個儲存中間節點grad的簡單例子:

(注意修改中間節點的梯度後,該節點之前變量的梯度也會受到鍊式法則的影響)

import torch

def grad_hook(grad):
    y_grad.append(grad)

y_grad = list()
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x+1
y.register_hook(grad_hook)

out = torch.mean(y*y)
out.backward()
print("type(y): ", type(y))
print("y.grad: ", y.grad)
print("y_grad[0]: ", y_grad[0])


>>> ('type(y): ', <class 'torch.Tensor'>)
>>> ('y.grad: ', None)
>>> ('y_grad[0]: ', tensor([[1.0000, 1.5000],
        [2.0000, 2.5000]]))      

上述代碼中,x是葉子結點,y是中間節點,反向傳播完成,out對y的梯度y.grad=None 證明中間節點梯度被釋放。

而通過自定義的hook函數:grad_hook 把y的梯度儲存到全局變量:y_grad = list()中。 是以可以在out.backward()結束後,仍舊可以在y_grad[0]中讀到y的梯度為tensor([0.2500, 0.2500, 0.2500, 0.2500])

 下面是一個修改grad的hook:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 檢視梯度
print("w.grad: ", w.grad)
handle.remove()  
# w.grad:  tensor([30.])=5*2*3      

handle.remove():  a handle that can be used to remove the added hook by calling handle.remove()

在實際代碼中,為了友善,也可以用 lambda 表達式來代替函數,簡寫為如下形式:

torch.Tensor.register_hook(lambda x: 2*x)  # 輸入grad,傳回2*grad,修改梯度值為原來的2倍,注意修改中間節點的梯度後,之前的梯度也會受到鍊式法則的影響

torch.Tensor.register_hook(lambda x: print(x))

一個變量可以綁定多個 hook_fn,反向傳播時,它們按綁定順序依次執行。 

下面介紹網絡子產品的hook:

網絡子產品 module 不像上一節中的 Tensor,擁有顯式的變量名可以直接通路,而是被封裝在神經網絡中間。我們通常隻能獲得網絡整體的輸入和輸出,對于夾在網絡中間的子產品,我們不但很難得知它輸入/輸出的梯度,甚至連它輸入輸出的數值都無法獲得。除非設計網絡時,在 forward 函數的傳回值中包含中間 module 的輸出,或者用很麻煩的辦法,把網絡按照 module 的名稱拆分再組合,讓中間層提取的 feature 暴露出來。

為了解決這個麻煩,PyTorch 設計了兩種 hook:register_forward_hook 和register_backward_hook,分别用來擷取正/反向傳播時,中間層子產品輸入和輸出的 feature/gradient,大大降低了擷取模型内部資訊流的難度。

nn.Module.register_forward_hook(hook_fn)

hook_fn(module, input, output) -> None。注意不能修改input和output 

  Module前向傳播中的hook,  module在前向傳播後,自動調用hook_fn函數。作用是擷取前向傳播過程中,各個網絡子產品的輸入和輸出

  hook_fn函數的輸入變量分别為:子產品,子產品的輸入,子產品的輸出,和對 Tensor 的 hook 不同,forward hook 不傳回任何值,也就是說不能用它來修改輸入或者輸出的值,但借助這個 hook,我們可以友善地用預訓練的神經網絡提取特征,而不用改變預訓練網絡的結構。如 用于提取特征圖。

import torch
from torch import nn

# 首先我們定義一個模型
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)
        self.initialize()
    
    # 為了友善驗證,我們将指定特殊的weight和bias
    def initialize(self):
        with torch.no_grad():
            self.fc1.weight = torch.nn.Parameter(
                torch.Tensor([[1., 2., 3.],
                              [-4., -5., -6.],
                              [7., 8., 9.],
                              [-10., -11., -12.]]))

            self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
            self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
            self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))

    def forward(self, x):
        o = self.fc1(x)
        o = self.relu1(o)
        o = self.fc2(o)
        return o

# 全局變量,用于存儲中間層的 feature
total_feat_out = []
total_feat_in = []

# 定義 forward hook function
def hook_fn_forward(module, input, output):
    print(module) # 用于區分子產品
    print('input', input) # 首先列印出來
    print('output', output)
    total_feat_out.append(output) # 然後分别存入全局 list 中
    total_feat_in.append(input)


model = Model()

modules = model.named_children() #
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)
    # module.register_backward_hook(hook_fn_backward)

# 注意下面代碼中 x 的次元,對于linear module,輸入一定是大于等于二維的
# (第一維是 batch size)。

x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_() 
o = model(x)
o.backward()

print('==========Saved inputs and outputs==========')
for idx in range(len(total_feat_in)):
    print('input: ', total_feat_in[idx])
    print('output: ', total_feat_out[idx])      

nn.Module.register_backward_hook(hook_fn)

和register_forward_hook相似,register_backward_hook 的作用是擷取神經網絡反向傳播過程中,各個子產品輸入端和輸出端的梯度值。其中hook_fn的函數簽名為:

hook_fn(module, grad_input, grad_output) -> Tensor or None

它的輸入變量分别為:子產品,子產品輸入端的梯度,子產品輸出端的梯度。需要注意的是,這裡的輸入端和輸出端,是站在前向傳播的角度的,而不是反向傳播的角度。例如線性子產品:o=W*x+b,其輸入端為 W,x 和 b,輸出端為 o。 能觀察得到:後一層的grad_input   和前一層的grad_output有關聯(可能相同)

如果子產品有多個輸入或者輸出的話,grad_input和grad_output可以是 tuple 類型。對于線性子產品:o=W*x+b ,它的輸入端包括了W、x 和 b 三部分,是以 grad_input 就是一個包含三個元素的 tuple。

這裡注意和 forward hook 的不同:

1.在 forward hook 中,input 是 x,而不包括 W 和 b。而 backward hook 的 input 則是 b, W, x 三者的梯度。

2.傳回 Tensor 或者 None,backward hook 函數不能直接改變它的輸入變量,但是可以傳回新的 grad_input,反向傳播到它上一個子產品。

 下面是例子:

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)
        self.initialize()

    def initialize(self):
        with torch.no_grad():
            self.fc1.weight = torch.nn.Parameter(
                torch.Tensor([[1., 2., 3.],
                              [-4., -5., -6.],
                              [7., 8., 9.],
                              [-10., -11., -12.]]))

            self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
            self.fc2.weight = torch.nn.Parameter(torch.Tensor([[10.0, 20.0, 30.0, 40.0]]))
            self.fc2.bias = torch.nn.Parameter(torch.Tensor([2.0]))

    def forward(self, x):
        o = self.fc1(x)
        o = self.relu1(o)
        o = self.fc2(o)
        return o

total_grad_out = []
total_grad_in = []

def hook_fn_backward(module, grad_input, grad_output):
    print(module) # 為了區分子產品
    # 為了符合反向傳播的順序,我們先列印 grad_output
    print('grad_output', grad_output) 
    # 再列印 grad_input
    print('grad_input', grad_input)
    # 儲存到全局變量
    total_grad_in.append(grad_input)
    total_grad_out.append(grad_output)

model = Model()

modules = model.named_children()
for name, module in modules:
    module.register_backward_hook(hook_fn_backward)

# 這裡的 requires_grad 很重要,如果不加,backward hook 執行到第一層,對 x 的導數将為 None,
# 此外再強調一遍 x 的次元,一定不能寫成 torch.Tensor([1.0, 1.0, 1.0]).requires_grad_() 否則 backward hook 會出問題。
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
o = model(x)
o.backward()

print('==========Saved inputs and outputs==========')
for idx in range(len(total_grad_in)):
    print('grad output: ', total_grad_out[idx])
    print('grad input: ', total_grad_in[idx])      

 輸出:

----------------------分割線-----------------------
Linear(in_features=4, out_features=1, bias=True)
grad_output (tensor([[1.]]),)
grad_input (tensor([1.]), tensor([[10., 20., 30., 40.]]), tensor([[ 7.],
        [ 0.],
        [27.],
        [ 0.]]))
ReLU()
grad_output (tensor([[10., 20., 30., 40.]]),)
grad_input (tensor([[10.,  0., 30.,  0.]]),)
Linear(in_features=3, out_features=4, bias=True)
grad_output (tensor([[10.,  0., 30.,  0.]]),)
grad_input (tensor([10.,  0., 30.,  0.]), tensor([[220., 260., 300.]]), tensor([[10.,  0., 30.,  0.],
        [10.,  0., 30.,  0.],
        [10.,  0., 30.,  0.]]))
==========Saved inputs and outputs==========
grad input:  (tensor([1.]), tensor([[10., 20., 30., 40.]]), tensor([[ 7.],
        [ 0.],
        [27.],
        [ 0.]]))
grad input:  (tensor([[10.,  0., 30.,  0.]]),)
grad input:  (tensor([10.,  0., 30.,  0.]), tensor([[220., 260., 300.]]), tensor([[10.,  0., 30.,  0.],
        [10.,  0., 30.,  0.],
        [10.,  0., 30.,  0.]]))      

設z=x*W1+b1,c=ReLu(z), y=c*W2+b2。可以根據鍊式求導法則自行驗證梯度。

backward hook 是按反向傳播順序調用子產品對應的hook,這裡要注意一下。是以結果是先列印fc2、再relu1、最後fc1。ReLu函數的導數(輸入>0?1:0)

 個人了解:

y對 Linear 的 grad_input 則是分别對 b1,x,W1的gradient

y對 Linear 的 grad_output 即是對z的gradient

pytorch hook 鈎子

參考部落格

繼續閱讀