關于tf.gather函數batch_dims參數用法的了解
- 0 前言
- 1. 不考慮batch_dims
- 2. 批處理(考慮batch_dims)
-
- 2.1 batch_dims=1
- 2.2 batch_dims=0
- 2.3 batch_dims>=2
- 2.4 batch_dims再降為1
- 2.5 再将axis降為1
- 2.6 batch_dims<0
- 2.7 batch_dims總結
- 3. 補充
- 4. 參數和傳回值
- 5. 其他相關論述
- 6. 附件
截至發稿(2023年3月2日)之前,全網對這個問題的解釋都不是很清楚(包括官網和英文網際網路),尤其是對
batch_dims
本質實體含義的解釋,以下内容根據
tf.gather
官網進行翻譯,并補充。
0 前言
根據索引
indices
從參數
axis
軸收集切片。 (棄用的參數,應該指下文的
validate_indices
)
tf.gather(
params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
已棄用:一些參數已棄用:(
validate_indices
)。 它們将在未來的版本中被删除。 更新說明:
validate_indices
參數無效。 索引(indices)總是在 CPU 上驗證,從不在 GPU 上驗證。
1. 不考慮batch_dims
根據索引
indices
從軸參數
axis
收集切片。
indices
必須是任意次元(通常是1-D)的整數張量。
Tensor.getitem
适用于标量、
tf.newaxis
和 python切片
tf.gather
擴充索引功能以處理索引(indices)張量。
在最簡單的情況下,它與标量索引功能相同:
>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
>>> params[3].numpy()
b'p3'
>>> tf.gather(params, 3).numpy()
b'p3'
最常見的情況是傳遞索引的單軸張量(這不能表示為python切片,因為索引不是連續的):
>>> indices = [2, 0, 2, 5]
>>> tf.gather(params, indices).numpy()
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
過程如下圖所示:
索引可以有任何形狀(shape)。 當參數
params
有 1 個軸(axis)時,輸出形狀等于輸入形狀:
>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
array([[b'p2', b'p0'],
[b'p2', b'p5']], dtype=object)
參數
params
也可以有任何形狀。
gather
可以根據參數
axis
(預設為 0)在任何軸(axis)上選擇切片。 它下面例程用于收集(gather)矩陣中的第一行,然後是列:
>>> params = tf.constant([[0, 1.0, 2.0],
... [10.0, 11.0, 12.0],
... [20.0, 21.0, 22.0],
... [30.0, 31.0, 32.0]])
>>> tf.gather(params, indices=[3,1]).numpy()
array([[30., 31., 32.],
[10., 11., 12.]], dtype=float32)
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
array([[ 2., 1.],
[12., 11.],
[22., 21.],
[32., 31.]], dtype=float32)
更一般地說:輸出形狀與輸入形狀相同,索引軸(indexed-axis)由索引(indices)的形狀代替。
>>> def result_shape(p_shape, i_shape, axis=0):
... return p_shape[:axis] + i_shape + p_shape[axis+1:]
>>>
>>> result_shape([1, 2, 3], [], axis=1)
[1, 3]
>>> result_shape([1, 2, 3], [7], axis=1)
[1, 7, 3]
>>> result_shape([1, 2, 3], [7, 5], axis=1)
[1, 7, 5, 3]
例如下面的例程:
>>> params.shape.as_list()
[4, 3]
>>> indices = tf.constant([[0, 2]])
>>> tf.gather(params, indices=indices, axis=0).shape.as_list()
[1, 2, 3]
>>> tf.gather(params, indices=indices, axis=1).shape.as_list()
[4, 1, 2]
>>> params = tf.random.normal(shape=(5, 6, 7, 8))
>>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
>>> result = tf.gather(params, indices, axis=2)
>>> result.shape.as_list()
[5, 6, 10, 11, 8]
這是因為每個索引都從
params
中擷取一個切片,并将其放置在輸出中的相應位置。 對于上面的例子
>>> # For any location in indices
>>> a, b = 0, 1
>>> tf.reduce_all(
... # the corresponding slice of the result
... result[:, :, a, b, :] ==
... # is equal to the slice of `params` along `axis` at the index.
... params[:, :, indices[a, b], :]
... ).numpy()
True
除此之外,我們再給
indices
增加一個元素,當進行
gather
的時候是沿着
params
的
axis=1
的上一個次元的元素進行循環的。即
params
的
axis=0
的元素分别為
[0, 1.0, 2.0]
、
[10.0, 11.0, 12.0]
、
[20.0, 21.0, 22.0]
、
[30.0, 31.0, 32.0]
,然後逐次對這四個元素裡面的
params
的
axis=1
的元素進行取
indices
對應的元素,四次循環完成整個
gather
>>> tf.gather(params, indices=[[2,1], [1,0]], axis=1).numpy()
array([[[ 2., 1.],
[ 1., 0.]],
[[12., 11.],
[11., 10.]],
[[22., 21.],
[21., 20.]],
[[32., 31.],
[31., 30.]]], dtype=float32)
2. 批處理(考慮batch_dims)
batch_dims
參數可以讓您從批次的每個元素中收集不同的項目。
ps:
可以先直接跳到到2.7 batch_dims總結,前後對照閱讀。
2.1 batch_dims=1
使用
batch_dims=1
相當于在
params
和
indices
的第一個軸(是指
axis=0
軸)上有一個外循環(在
axis=0
軸上的元素上進行循環):
>>> params = tf.constant([
... [0, 0, 1, 0, 2],
... [3, 0, 0, 0, 4],
... [0, 5, 0, 6, 0]])
>>> indices = tf.constant([
... [2, 4],
... [0, 4],
... [1, 3]])
>>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
array([[1, 2],
[3, 4],
[5, 6]], dtype=int32)
等價于:
>>> def manually_batched_gather(params, indices, axis):
... batch_dims=1
... result = []
... for p,i in zip(params, indices): # 這就是上文所說的外循環
... r = tf.gather(p, i, axis=axis-batch_dims)
... result.append(r)
... return tf.stack(result)
>>> manually_batched_gather(params, indices, axis=1).numpy()
array([[1, 2],
[3, 4],
[5, 6]], dtype=int32)
接下來将循環裡
zip
的結果列印如下,說明外循環将
params
和
indices
在第一個軸上先zip成三個元組
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([0, 0, 1, 0, 2], [2, 4]),
# ([3, 0, 0, 0, 4], [0, 4]),
# ([0, 5, 0, 6, 0], [1, 3])]
然後分别對
[0, 0, 1, 0, 2]
與
[2, 4]
、
[3, 0, 0, 0, 4]
與
[0, 4]
、
[0, 5, 0, 6, 0]
與
[1, 3]
,沿着重組之後的
axis = 0
(即重組之前的
axis = 1
,這就是為什麼後面所說的必須
axis
>=
batch_dims
)進行
gather
。
2.2 batch_dims=0
是以可以總結:
batch_dims
是指最終對哪一個次元的張量進行對照
gather
,是以當
batch_dims=0
時,實際上就是将兩個整個張量組包,也就是上面第一階段的省略
batch_dims
的狀态。
此時,相當于将兩個張量在外面添加一個次元之後再
zip
,相當于沒
zip
直接
gather
。是以,以下兩條指令等價,因為
batch_dims
預設值為
。
params = tf.constant([[ # 相對于上文該張量增加了一個次元
[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]]])
indices = tf.constant([[ # 相對于上文該張量增加了一個次元
[2, 4],
[0, 4],
[1, 3]]])
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([[0, 0, 1, 0, 2], [3, 0, 0, 0, 4], [0, 5, 0, 6, 0]],
# [[2, 4], [0, 4], [1, 3]])]
tf.gather(params, indices, axis=1, batch_dims=0).numpy()
# 等價于
tf.gather(params, indices, axis=1).numpy()
# 輸出結果為
# array([[[1, 2],
# [0, 2],
# [0, 0]],
#
# [[0, 4],
# [3, 4],
# [0, 0]],
#
# [[0, 0],
# [0, 0],
# [5, 6]]], dtype=int32)
2.3 batch_dims>=2
較高的
batch_dims
值相當于在
params
和
indices
的外軸上進行多個嵌套循環。 是以整體形狀函數是
>>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
... return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
>>> batched_result_shape(
... p_shape=params.shape.as_list(),
... i_shape=indices.shape.as_list(),
... axis=1,
... batch_dims=1)
[3, 2]
>>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
[3, 2]
舉例來說,
params
和
indices
升高一個次元,即
batch_dims=2
,這時按照限制條件隻能
axis=2
params = tf.constant([ # 升高一個次元
[[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]],
[[1, 8, 4, 2, 2],
[9, 6, 2, 3, 0],
[7, 2, 8, 6, 3]]])
indices = tf.constant([ # 升高一個次元
[[2, 4],
[0, 4],
[1, 3]],
[[1, 3],
[2, 1],
[4, 2]]])
# 進行batch_dims高值gather計算
tf.gather(params, indices, axis=2, batch_dims=2).numpy()
# 則上面的運算等價于
def manually_batched_gather_3d(params, indices, axis):
batch_dims=2
result = []
for p,i in zip(params, indices): # 這裡面進行了batch_dims層(也就是2層)嵌套for循環
result_2 = []
for p_2, i_2 in zip(p,i):
r = tf.gather(p_2, i_2, axis=axis-batch_dims) # 這裡告訴我們為什麼axis必須>=batch_dims
result_2.append(r)
result.append(result_2)
return tf.stack(result)
manually_batched_gather_3d(params, indices, axis=2).numpy()
# array([[[1, 2],
# [3, 4],
# [5, 6]],
#
# [[8, 2],
# [2, 6],
# [3, 8]]], dtype=int32)
下面來解釋一下上面程式的運作過程,在上面的
manually_batched_gather_3d
運作過程中第一層
zip
的作用如下
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# 列印得到如下list,該list有兩個元組組成,都是将兩個參數的axis=0軸上的兩個二維張量,分别進行了組包
# [([[0, 0, 1, 0, 2],
# [3, 0, 0, 0, 4],
# [0, 5, 0, 6, 0]], # 到這兒為params的axis=0軸上的[0]二維張量
# [[2, 4],
# [0, 4],
# [1, 3]]), # 到這兒為indices的axis=0軸上的[0]二維張量
#
# ([[1, 8, 4, 2, 2],
# [9, 6, 2, 3, 0],
# [7, 2, 8, 6, 3]], # 到這兒為params的axis=0軸上的[1]二維張量
# [[1, 3],
# [2, 1],
# [4, 2]])] # 到這兒為indices的axis=0軸上的[1]二維張量
然後進入第一層for循環的第一次循環,将
zip
之後的兩個元組中的第一個元組,拿過來分别賦給
p
、
i
:
p=tf.Tensor(
[[0 0 1 0 2]
[3 0 0 0 4]
[0 5 0 6 0]], shape=(3, 5), dtype=int32)
i=tf.Tensor(
[[2 4]
[0 4]
[1 3]], shape=(3, 2), dtype=int32)
在第二層
for
之前插入,得到第二層的
zip
結果
print(list(zip(p.numpy().tolist(), i.numpy().tolist())))
# [([0, 0, 1, 0, 2], [2, 4]),
# ([3, 0, 0, 0, 4], [0, 4]),
# ([0, 5, 0, 6, 0], [1, 3])]
則開始第二層for的第一次循環,則
# p_2 = tf.Tensor([0 0 1 0 2], shape=(5,), dtype=int32)
# i_2 = tf.Tensor([2 4], shape=(2,), dtype=int32)
# r = tf.Tensor([1 2], shape=(2,), dtype=int32)
這之後第二層for循環再進行2次循環,退回到第一層大循環,第一層大循環再進行一次上述循環即完成了整個循環。
2.4 batch_dims再降為1
你會發現,下面兩條指令等價,即
batch_dims=1
隻有一層循環,隻
zip
一次
tf.gather(params, indices, axis=2, batch_dims=1).numpy()
# 等價于
manually_batched_gather(params, indices, axis=2).numpy()
# [[[[1 2]
# [0 2]
# [0 0]]
#
# [[0 4]
# [3 4]
# [0 0]]
#
# [[0 0]
# [0 0]
# [5 6]]]
#
#
# [[[8 2]
# [4 8]
# [2 4]]
#
# [[6 3]
# [2 6]
# [0 2]]
#
# [[2 6]
# [8 2]
# [3 8]]]]
2.5 再将axis降為1
還需修改一下
indices
,因為下文有對
indices
的限制——必須在
[0, params.shape[axis]]
範圍内,此時
params.shape
為
(2, 3, 5)
,則
params.shape[1]=3
,是以
indices
隻能等于
或
1
或
2
,如果>=3索引的時候就會溢出。此時還是
batch_dims=1
隻有一層循環,隻
zip
一次,隻是改變了索引軸。
indices = tf.constant([
[[1, 0],
[2, 1],
[2, 0]],
[[2, 0],
[0, 1],
[1, 2]]])
tf.gather(params, indices, axis=1, batch_dims=1).numpy()
# 等價于
manually_batched_gather(params, indices, axis=1).numpy()
# array([[[[3, 0, 0, 0, 4],
# [0, 0, 1, 0, 2]],
#
# [[0, 5, 0, 6, 0],
# [3, 0, 0, 0, 4]],
#
# [[0, 5, 0, 6, 0],
# [0, 0, 1, 0, 2]]],
#
#
# [[[7, 2, 8, 6, 3],
# [1, 8, 4, 2, 2]],
#
# [[1, 8, 4, 2, 2],
# [9, 6, 2, 3, 0]],
#
# [[9, 6, 2, 3, 0],
# [7, 2, 8, 6, 3]]]], dtype=int32)>>
2.6 batch_dims<0
因為
params
和
indices
一共由3各次元——
、
1
、
2
,其對應的負次元就是
-3
、
-2
、
-1
,是以下面兩條指令等價
a = tf.gather(params, indices, axis=2, batch_dims=1).numpy()
pprint(a)
# 等價于
a = tf.gather(params, indices, axis=2, batch_dims=-2).numpy()
pprint(a)
2.7 batch_dims總結
故個人認為,
batch_dims
是由batch和dimensions兩個單詞縮寫而成,因為dimensions為複數是以可以翻譯為“批量次元數”(自己翻譯沒有查到文獻),可以指批處理
batch_dims
個次元,如果是正數可以了解成嵌套幾層循環或者進行幾次
zip
,如果是負數需要轉化為對應的正次元再進行上述了解;也可以是指組包到哪一個次元上,如果是負數也同樣适用于這種解釋。
batch_dims
極大的擴充了
gather
的功能,使你可以将
params
和
indices
在對應的某個次元上分别進行
gather
然後再
stack
。
ps:關于
batch_dims
的這個解釋同樣也适用于tf.gather_nd。
3. 補充
如果您需要使用諸如 tf.argsort 或 tf.math.top_k 之類的操作的索引,其中索引的最後一個次元在相應位置索引到輸入的最後一個次元,這自然會出現。 在這種情況下,您可以使用 tf.gather(values, indices, batch_dims=-1)。
4. 參數和傳回值
參數 | |
---|---|
| 從中收集值的 (張量)。其秩(rank)必須至少為 + 1。 |
| 索引張量。 必須是以下類型之一: 、 。 這些值必須在 範圍内。 |
| 已棄用,沒有任何作用。 索引總是在 CPU 上驗證,從不在 GPU 上驗證。 注意:在 CPU 上,如果發現越界索引,則會引發錯誤。 在 GPU 上,如果發現越界索引,則将 0 存儲在相應的輸出值中。 |
| 一個 ((張量))。 必須是以下類型之一: 、 。 從參數 中的 軸收集索引。 必須大于或等于 。 預設為第一個**非批次次元 **。 支援負索引。 |
| 一個 (整數)。 批量次元(batch dimensions)的數量。 必須小于或等于 。 |
| 操作的名稱(可選)。 |
傳回值 |
---|
一個 (張量), 與 具有相同的類型。 |
5. 其他相關論述
下面幾篇部落格,相對于官網手冊都有新的資訊增量,可以作為參考
- 知網《tf.gather()函數》,使用索引推演的方式在次元和操作兩個方面進行了解,但是其關于
的描述不夠充分且有些片面;batch_dims
- 知乎《tf.gather()函數總結》,舉了一個新的例子,但是
還是隻到了1,沒有很好的歸納其真正的實體意義;batch_dims
- CSDN《tf.gather函數》,跟上一篇的情況差不多。
6. 附件
上文用到的調試程式,
可以忽略
import tensorflow as tf
from pprint import pprint
params = tf.constant([[0, 1.0, 2.0],
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
[30.0, 31.0, 32.0]])
a = tf.gather(params, indices=[[2,1], [1,0]], axis=1).numpy()
pprint(a)
params = tf.constant([
[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]])
indices = tf.constant([
[2, 4],
[0, 4],
[1, 3]])
a = tf.gather(params, indices, axis=1, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=1, batch_dims=-1).numpy()
pprint(a)
def manually_batched_gather(params, indices, axis):
batch_dims=1
result = []
for p,i in zip(params, indices):
r = tf.gather(p, i, axis=axis-batch_dims)
result.append(r)
return tf.stack(result)
manually_batched_gather(params, indices, axis=1).numpy()
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
tf.gather(params, indices, axis=1, batch_dims=0).numpy()
tf.gather(params, indices, axis=1).numpy()
# tf.gather(params, indices, axis=0, batch_dims=0).numpy()
params = tf.constant([[
[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]]])
indices = tf.constant([[
[2, 4],
[0, 4],
[1, 3]]])
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([[0, 0, 1, 0, 2], [3, 0, 0, 0, 4], [0, 5, 0, 6, 0]],
# [[2, 4], [0, 4], [1, 3]])]
params_1 = [[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]],
indices_1 = [[2, 4],
[0, 4],
[1, 3]]
# a = tf.gather(params_1, indices_1, axis=0).numpy()
params = tf.constant([
[[0, 0, 1, 0, 2],
[3, 0, 0, 0, 4],
[0, 5, 0, 6, 0]],
[[1, 8, 4, 2, 2],
[9, 6, 2, 3, 0],
[7, 2, 8, 6, 3]]])
indices = tf.constant([
[[2, 4],
[0, 4],
[1, 3]],
[[1, 3],
[2, 1],
[4, 2]]])
a = tf.gather(params, indices, axis=2, batch_dims=2).numpy()
pprint(a)
a = tf.gather(params, indices, axis=2, batch_dims=-1).numpy()
pprint(a)
print(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([[0, 0, 1, 0, 2],
# [3, 0, 0, 0, 4],
# [0, 5, 0, 6, 0]],
# [[2, 4],
# [0, 4],
# [1, 3]]),
#
# ([[1, 8, 4, 2, 2],
# [9, 6, 2, 3, 0],
# [7, 2, 8, 6, 3]],
# [[1, 3],
# [2, 1],
# [4, 2]])]
def manually_batched_gather_3(params, indices, axis):
batch_dims=2
result = []
for p,i in zip(params, indices):
result_2 = []
print(list(zip(p.numpy().tolist(), i.numpy().tolist())))
for p_2, i_2 in zip(p,i):
r = tf.gather(p_2, i_2, axis=axis-batch_dims)
result_2.append(r)
result.append(result_2)
return tf.stack(result)
manually_batched_gather_3(params, indices, axis=2).numpy()
# <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
# array([[[1, 2],
# [3, 4],
# [5, 6]],
#
# [[8, 2],
# [2, 6],
# [3, 8]]], dtype=int32)>>
# [([0, 0, 1, 0, 2], [2, 4]),
# ([3, 0, 0, 0, 4], [0, 4]),
# ([0, 5, 0, 6, 0], [1, 3])]
a = tf.gather(params, indices, axis=2, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=2, batch_dims=-2).numpy()
pprint(a)
manually_batched_gather(params, indices, axis=2).numpy()
# [[[[1 2]
# [0 2]
# [0 0]]
#
# [[0 4]
# [3 4]
# [0 0]]
#
# [[0 0]
# [0 0]
# [5 6]]]
#
#
# [[[8 2]
# [4 8]
# [2 4]]
#
# [[6 3]
# [2 6]
# [0 2]]
#
# [[2 6]
# [8 2]
# [3 8]]]]
indices = tf.constant([
[[1, 0],
[2, 1],
[2, 0]],
[[2, 0],
[0, 1],
[1, 2]]])
a = tf.gather(params, indices, axis=1, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=1, batch_dims=-2).numpy()
pprint(a)
manually_batched_gather(params, indices, axis=1).numpy()
# array([[[[3, 0, 0, 0, 4],
# [0, 0, 1, 0, 2]],
#
# [[0, 5, 0, 6, 0],
# [3, 0, 0, 0, 4]],
#
# [[0, 5, 0, 6, 0],
# [0, 0, 1, 0, 2]]],
#
#
# [[[7, 2, 8, 6, 3],
# [1, 8, 4, 2, 2]],
#
# [[1, 8, 4, 2, 2],
# [9, 6, 2, 3, 0]],
#
# [[9, 6, 2, 3, 0],
# [7, 2, 8, 6, 3]]]], dtype=int32)>>