在Pytorch中,
index_select
和
gather
均是被用于張量選取的常用函數,本文通過執行個體來對比這兩個函數。
1. index_select
沿着張量的某個
dim
方向,按照
index
規定的選取指定的低一次元張量元素整體,在拼接成一個張量。其官方解釋如下:
torch.index_select(input, dim, index, out=None)
"""
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor
"""
先簡單看兩個示例:
示例1:沿着
dim=0
的方向進行
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.index_select(a, dim=0, index=torch.tensor([0,1,0,1]))
# b為tensor([[1, 2, 3],
# [4, 5, 6],
# [1, 2, 3],
# [4, 5, 6]])
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLiAzNfRHLGZkRGZkRfJ3bs92YsYTMfVmepNHL9sGSi1WOWFGbaJjYxQmMMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2X0hXZ0xCMx81dvRWYoNHLrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdssmch1mclRXY39CXldWYtlWPzNXZj9mcw1ycz9WL49zZuBnL0cTN1UDMzYTM0IjMwAjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
顯然,對于二維張量,
dim=0
意味着按照
index
的編号選取指定的行,拼接成目标張量。其傳回值仍保持和原始張量相同的
ndim
。
示例2:沿着
dim=1
的方向進行
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.index_select(a, dim=1, index=torch.tensor([1,1]))
# b為tensor([[2, 2],
# [5, 5]])
對于二維張量,
dim=1
意味着按照
index
的編号選取指定的列,拼接成目标張量。其傳回值仍保持和原始張量相同的
ndim
。
根據上述兩個例子,可見
index_select
的作用間接明了,即選取某個
dim
上的若幹個元素,将其拼接為目标張量。其中
index
為一個一維張量,表明該
dim
上做選取的具體元素,傳回張量與原張量的
ndim
一緻。
2. gather
相較于
index_select
,
gather
就顯得讓人難以了解的多。個人了解,其操作相當于用于沿着張量的某個
dim
方向,按照
index
規定的選取指定元素,構成該為次元上的每個子張量,最後拼接成一個張量。其官方解釋如下:
torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
"""
是不是還是令人費解?我們先以兩個2維張量的例子來說明:
示例1:沿着
dim=1
的方向進行選擇
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.gather(input=a, dim=1, index=torch.tensor([[2,0,2,1], [1,1,0,0]]))
# 傳回值為 tensor([[3, 1, 3, 2],
# [5, 5, 4, 4]])
其操作過程可參照下圖:
由上圖可見,
dim=1
表示在二維張量中,以行為機關,對每行中的元素,按照
index
的索引号進行選取,再拼接到一起。從張量
shape
上看,其在
dim=0
上保持一緻,對
dim=1
進行了放大或縮小。
對于更一般的張量,
gather
的過程可了解為沿着
dim
維的
size
,對各個子張量進行選取和重新的拼接,是以其傳回值和原始張量的
ndim
是相同的。
示例2:沿着
dim=0
的方向進行選擇
import torch
a = torch.tensor([[1,2,3], [4,5,6]])
# a為tensor([[1, 2, 3],
# [4, 5, 6]])
b = torch.gather(input=a, dim=0, index=torch.tensor([[0, 1, 0], [1,0,1], [0, 0, 0],[1,1,1]]))
# 傳回值為 tensor([[1, 5, 3],
# [4, 2, 6],
# [1, 2, 3],
# [4, 5, 6]])
其操作過程可參照下圖:
對于二維張量,其操作過程與
dim=1
相反,即以行為機關,對每列中的元素,按照
index
的索引号進行選取,再拼接到一起。從張量
shape
上看,其在
dim=1
上保持一緻,對
dim=0
進行了放大或縮小。
示例3:三維張量的例子
a = torch.arange(24).reshape(2,3,4)
# a為tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
b = torch.gather(a, dim=2, index=torch.tensor([[[2], [1], [0]], [[1], [2], [3]]]))
# b為tensor([[[ 2],
# [5],
# [8]],
# [[13],
# [18],
# [23]]])
簡單解釋下,其選取的
dim=2
,即沿着三維張量最内層的張量進行元素選取和拼接,其隻選取了一次。是以,上述操作可了解為每個最内層選取一個元素。
3. 總結
index_select
和
gather
雖然都可用于張量元素的選取和重塑,主要參數的命名也類似,但其功能截然不同。簡要而言:
(1)
index_select
用于對
dim
方向各子張量的整體選取和拼接,其中的
index
為一維張量;
(2)
gather
用于對
dim
方向各子張量的元素在其它次元方向上的選取和拼接,其中的
index
為與原張量同
ndim
的張量。