天天看點

圖解pytorch padding方法 ReflectionPad2d

torch.nn.ReflectionPad2d(padding)

這個函數簡單來說就是:利用輸入邊界的反射來填充輸入張量。

官方文檔裡給了該padding的輸入輸出如下所示:

CLASS torch.nn.ReflectionPad2d(padding: Union[T, Tuple[T, T, T, T]])           

複制

  • Input: ( N , C , H i n , W i n ) (N, C, H_{in}, W_{in}) (N,C,Hin​,Win​)
  • Output: ( N , C , H o u t , W o u t ) (N, C, H_{out}, W_{out}) (N,C,Hout​,Wout​)where

    KaTeX parse error: Expected '}', got '_' at position 33: …+ \text{padding_̲top} + \text{pa…

    KaTeX parse error: Expected '}', got '_' at position 33: …+ \text{padding_̲left} + \text{p…

官方文檔比較簡單,我們先直接看代碼:

import torch
import torch.nn as nn

m = nn.ReflectionPad2d(1)
input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
input
-------
tensor([[[[0., 1., 2.],
          [3., 4., 5.],
          [6., 7., 8.]]]])
----------------------------
m(input)
-----
tensor([[[[4., 3., 4., 5., 4.],
          [1., 0., 1., 2., 1.],
          [4., 3., 4., 5., 4.],
          [7., 6., 7., 8., 7.],
          [4., 3., 4., 5., 4.]]]])
------------------------------
n = nn.ReflectionPad2d(2)
n(input)
-------
tensor([[[[8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.],
          [5., 4., 3., 4., 5., 4., 3.],
          [8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.]]]])
-------------------------------
m = nn.ReflectionPad2d((1, 1, 2, 0))
m(input)
-----
tensor([[[[8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.],
          [5., 4., 3., 4., 5., 4., 3.],
          [8., 7., 6., 7., 8., 7., 6.],
          [5., 4., 3., 4., 5., 4., 3.],
          [2., 1., 0., 1., 2., 1., 0.]]]])           

複制

1. 填充一個整數

圖解pytorch padding方法 ReflectionPad2d
圖解pytorch padding方法 ReflectionPad2d

從上面的例子看,填充的值就是對應邊界旁邊的值按填充參數反射過去,即反射填充。

填充順序是左-右-上-下,

nn.ReflectionPad2d((1, 1, 2, 0))

中,這幾個數字表示左右上下分别要填充的層數。我們畫圖表示更清楚,如下:

圖解pytorch padding方法 ReflectionPad2d

幾幅圖已經比較清楚的表達了該padding方法的過程。

Note: padding的大小要小于輸入的大小

Reference:

幾幅圖已經比較清楚的表達了該padding方法的過程。

Note: padding的大小要小于輸入的大小

Reference:

PyTorch: REFLECTIONPAD2D