天天看點

【Pytorch】index_select和gather函數的對比

在Pytorch中,

index_select

gather

均是被用于張量選取的常用函數,本文通過執行個體來對比這兩個函數。

1. index_select

沿着張量的某個

dim

方向,按照

index

規定的選取指定的低一次元張量元素整體,在拼接成一個張量。其官方解釋如下:

torch.index_select(input, dim, index, out=None) 
"""
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor
"""
           

先簡單看兩個示例:

示例1:沿着

dim=0

的方向進行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=0, index=torch.tensor([0,1,0,1]))
# b為tensor([[1, 2, 3],
#        [4, 5, 6],
#        [1, 2, 3],
#        [4, 5, 6]])
           
【Pytorch】index_select和gather函數的對比

顯然,對于二維張量,

dim=0

意味着按照

index

的編号選取指定的行,拼接成目标張量。其傳回值仍保持和原始張量相同的

ndim

示例2:沿着

dim=1

的方向進行

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.index_select(a, dim=1, index=torch.tensor([1,1]))
# b為tensor([[2, 2],
#          [5, 5]])
           
【Pytorch】index_select和gather函數的對比

對于二維張量,

dim=1

意味着按照

index

的編号選取指定的列,拼接成目标張量。其傳回值仍保持和原始張量相同的

ndim

根據上述兩個例子,可見

index_select

的作用間接明了,即選取某個

dim

上的若幹個元素,将其拼接為目标張量。其中

index

為一個一維張量,表明該

dim

上做選取的具體元素,傳回張量與原張量的

ndim

一緻。

2. gather

相較于

index_select

gather

就顯得讓人難以了解的多。個人了解,其操作相當于用于沿着張量的某個

dim

方向,按照

index

規定的選取指定元素,構成該為次元上的每個子張量,最後拼接成一個張量。其官方解釋如下:

torch.gather(input, dim, index, out=None, sparse_grad=False)

"""
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
"""
           

是不是還是令人費解?我們先以兩個2維張量的例子來說明:

示例1:沿着

dim=1

的方向進行選擇

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=1, index=torch.tensor([[2,0,2,1], [1,1,0,0]]))
# 傳回值為 tensor([[3, 1, 3, 2],
#         [5, 5, 4, 4]])
           

其操作過程可參照下圖:

【Pytorch】index_select和gather函數的對比

由上圖可見,

dim=1

表示在二維張量中,以行為機關,對每行中的元素,按照

index

的索引号進行選取,再拼接到一起。從張量

shape

上看,其在

dim=0

上保持一緻,對

dim=1

進行了放大或縮小。

對于更一般的張量,

gather

的過程可了解為沿着

dim

維的

size

,對各個子張量進行選取和重新的拼接,是以其傳回值和原始張量的

ndim

是相同的。

示例2:沿着

dim=0

的方向進行選擇

import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
#        [4, 5, 6]])

b = torch.gather(input=a, dim=0, index=torch.tensor([[0, 1, 0], [1,0,1], [0, 0, 0],[1,1,1]]))
# 傳回值為 tensor([[1, 5, 3],
#        [4, 2, 6],
#        [1, 2, 3],
#        [4, 5, 6]])
           

其操作過程可參照下圖:

【Pytorch】index_select和gather函數的對比

對于二維張量,其操作過程與

dim=1

相反,即以行為機關,對每列中的元素,按照

index

的索引号進行選取,再拼接到一起。從張量

shape

上看,其在

dim=1

上保持一緻,對

dim=0

進行了放大或縮小。

示例3:三維張量的例子

a = torch.arange(24).reshape(2,3,4)
# a為tensor([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

b = torch.gather(a, dim=2, index=torch.tensor([[[2], [1], [0]], [[1], [2], [3]]]))

# b為tensor([[[ 2],
#            [5],
#            [8]],

#        [[13],
#        [18],
#        [23]]])
           

簡單解釋下,其選取的

dim=2

,即沿着三維張量最内層的張量進行元素選取和拼接,其隻選取了一次。是以,上述操作可了解為每個最内層選取一個元素。

3. 總結

index_select

gather

雖然都可用于張量元素的選取和重塑,主要參數的命名也類似,但其功能截然不同。簡要而言:

(1)

index_select

用于對

dim

方向各子張量的整體選取和拼接,其中的

index

為一維張量;

(2)

gather

用于對

dim

方向各子張量的元素在其它次元方向上的選取和拼接,其中的

index

為與原張量同

ndim

的張量。

繼續閱讀