天天看點

關于tf.gather函數batch_dims參數用法的了解0 前言1. 不考慮batch_dims2. 批處理(考慮batch_dims)3. 補充4. 參數和傳回值5. 其他相關論述6. 附件

關于tf.gather函數batch_dims參數用法的了解

  • 0 前言
  • 1. 不考慮batch_dims
  • 2. 批處理(考慮batch_dims)
    • 2.1 batch_dims=1
    • 2.2 batch_dims=0
    • 2.3 batch_dims>=2
    • 2.4 batch_dims再降為1
    • 2.5 再将axis降為1
    • 2.6 batch_dims<0
    • 2.7 batch_dims總結
  • 3. 補充
  • 4. 參數和傳回值
  • 5. 其他相關論述
  • 6. 附件

截至發稿(2023年3月2日)之前,全網對這個問題的解釋都不是很清楚(包括官網和英文網際網路),尤其是對

batch_dims

本質實體含義的解釋,以下内容根據

tf.gather

官網進行翻譯,并補充。

0 前言

根據索引

indices

從參數

axis

軸收集切片。 (棄用的參數,應該指下文的

validate_indices

tf.gather(
    params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
           

已棄用:一些參數已棄用:(

validate_indices

)。 它們将在未來的版本中被删除。 更新說明:

validate_indices

參數無效。 索引(indices)總是在 CPU 上驗證,從不在 GPU 上驗證。

1. 不考慮batch_dims

根據索引

indices

從軸參數

axis

收集切片。

indices

必須是任意次元(通常是1-D)的整數張量。

Tensor.getitem

适用于标量、

tf.newaxis

和 python切片

tf.gather

擴充索引功能以處理索引(indices)張量。

在最簡單的情況下,它與标量索引功能相同:

>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
>>> params[3].numpy()
b'p3'
>>> tf.gather(params, 3).numpy()
b'p3'
           

最常見的情況是傳遞索引的單軸張量(這不能表示為python切片,因為索引不是連續的):

>>> indices = [2, 0, 2, 5]
>>> tf.gather(params, indices).numpy()
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
           

過程如下圖所示:

關于tf.gather函數batch_dims參數用法的了解0 前言1. 不考慮batch_dims2. 批處理(考慮batch_dims)3. 補充4. 參數和傳回值5. 其他相關論述6. 附件

索引可以有任何形狀(shape)。 當參數

params

有 1 個軸(axis)時,輸出形狀等于輸入形狀:

>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
array([[b'p2', b'p0'],
[b'p2', b'p5']], dtype=object)
           

參數

params

也可以有任何形狀。

gather

可以根據參數

axis

(預設為 0)在任何軸(axis)上選擇切片。 它下面例程用于收集(gather)矩陣中的第一行,然後是列:

>>> params = tf.constant([[0, 1.0, 2.0],
...                       [10.0, 11.0, 12.0],
...                       [20.0, 21.0, 22.0],
...                       [30.0, 31.0, 32.0]])
>>> tf.gather(params, indices=[3,1]).numpy()
array([[30., 31., 32.],
              [10., 11., 12.]], dtype=float32)
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
array([[ 2., 1.],
              [12., 11.],
              [22., 21.],
              [32., 31.]], dtype=float32)
           

更一般地說:輸出形狀與輸入形狀相同,索引軸(indexed-axis)由索引(indices)的形狀代替。

>>> def result_shape(p_shape, i_shape, axis=0):
...   return p_shape[:axis] + i_shape + p_shape[axis+1:]
>>> 
>>> result_shape([1, 2, 3], [], axis=1)
[1, 3]
>>> result_shape([1, 2, 3], [7], axis=1)
[1, 7, 3]
>>> result_shape([1, 2, 3], [7, 5], axis=1)
[1, 7, 5, 3]
           

例如下面的例程:

>>> params.shape.as_list()
[4, 3]
>>> indices = tf.constant([[0, 2]])
>>> tf.gather(params, indices=indices, axis=0).shape.as_list()
[1, 2, 3]
>>> tf.gather(params, indices=indices, axis=1).shape.as_list()
[4, 1, 2]
           
>>> params = tf.random.normal(shape=(5, 6, 7, 8))
>>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
>>> result = tf.gather(params, indices, axis=2)
>>> result.shape.as_list()
[5, 6, 10, 11, 8]
           

這是因為每個索引都從

params

中擷取一個切片,并将其放置在輸出中的相應位置。 對于上面的例子

>>> # For any location in indices
>>> a, b = 0, 1
>>> tf.reduce_all(
...     # the corresponding slice of the result
...     result[:, :, a, b, :] ==
...     # is equal to the slice of `params` along `axis` at the index.
...     params[:, :, indices[a, b], :]
...  ).numpy()
True
           

除此之外,我們再給

indices

增加一個元素,當進行

gather

的時候是沿着

params

axis=1

的上一個次元的元素進行循環的。即

params

axis=0

的元素分别為

[0, 1.0, 2.0]

[10.0, 11.0, 12.0]

[20.0, 21.0, 22.0]

[30.0, 31.0, 32.0]

,然後逐次對這四個元素裡面的

params

axis=1

的元素進行取

indices

對應的元素,四次循環完成整個

gather

>>> tf.gather(params, indices=[[2,1], [1,0]], axis=1).numpy()
array([[[ 2.,  1.],
        [ 1.,  0.]],
       [[12., 11.],
        [11., 10.]],
       [[22., 21.],
        [21., 20.]],
       [[32., 31.],
        [31., 30.]]], dtype=float32)
           

2. 批處理(考慮batch_dims)

batch_dims

參數可以讓您從批次的每個元素中收集不同的項目。

ps:

可以先直接跳到到2.7 batch_dims總結,前後對照閱讀。

2.1 batch_dims=1

使用

batch_dims=1

相當于在

params

indices

的第一個軸(是指

axis=0

軸)上有一個外循環(在

axis=0

軸上的元素上進行循環):

>>> params = tf.constant([
...     [0, 0, 1, 0, 2],
...     [3, 0, 0, 0, 4],
...     [0, 5, 0, 6, 0]])
>>> indices = tf.constant([
...     [2, 4],
...     [0, 4],
...     [1, 3]])
           
>>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
array([[1, 2],
[3, 4],
[5, 6]], dtype=int32)
           

等價于:

>>> def manually_batched_gather(params, indices, axis):
...            batch_dims=1
...            result = []
...            for p,i in zip(params, indices):  # 這就是上文所說的外循環
...                   r = tf.gather(p, i, axis=axis-batch_dims)
...                   result.append(r)
...            return tf.stack(result)
>>> manually_batched_gather(params, indices, axis=1).numpy()
array([[1, 2],
             [3, 4],
             [5, 6]], dtype=int32)
           

接下來将循環裡

zip

的結果列印如下,說明外循環将

params

indices

在第一個軸上先zip成三個元組

pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([0, 0, 1, 0, 2], [2, 4]), 
# ([3, 0, 0, 0, 4], [0, 4]), 
# ([0, 5, 0, 6, 0], [1, 3])]
           

然後分别對

[0, 0, 1, 0, 2]

[2, 4]

[3, 0, 0, 0, 4]

[0, 4]

[0, 5, 0, 6, 0]

[1, 3]

,沿着重組之後的

axis = 0

(即重組之前的

axis = 1

,這就是為什麼後面所說的必須

axis

>=

batch_dims

)進行

gather

2.2 batch_dims=0

是以可以總結:

batch_dims

是指最終對哪一個次元的張量進行對照

gather

,是以當

batch_dims=0

時,實際上就是将兩個整個張量組包,也就是上面第一階段的省略

batch_dims

的狀态。

此時,相當于将兩個張量在外面添加一個次元之後再

zip

,相當于沒

zip

直接

gather

。是以,以下兩條指令等價,因為

batch_dims

預設值為

params = tf.constant([[  # 相對于上文該張量增加了一個次元
    [0, 0, 1, 0, 2],
    [3, 0, 0, 0, 4],
    [0, 5, 0, 6, 0]]])
indices = tf.constant([[  # 相對于上文該張量增加了一個次元
    [2, 4],
    [0, 4],
    [1, 3]]])
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# [([[0, 0, 1, 0, 2], [3, 0, 0, 0, 4], [0, 5, 0, 6, 0]],
#   [[2, 4], [0, 4], [1, 3]])]

tf.gather(params, indices, axis=1, batch_dims=0).numpy()
# 等價于
tf.gather(params, indices, axis=1).numpy()
# 輸出結果為
# array([[[1, 2],
#         [0, 2],
#         [0, 0]],
# 
#        [[0, 4],
#         [3, 4],
#         [0, 0]],
# 
#        [[0, 0],
#         [0, 0],
#         [5, 6]]], dtype=int32)
           

2.3 batch_dims>=2

較高的

batch_dims

值相當于在

params

indices

的外軸上進行多個嵌套循環。 是以整體形狀函數是

>>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
...             return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]

>>> batched_result_shape(
...              p_shape=params.shape.as_list(),
...              i_shape=indices.shape.as_list(),
...              axis=1,
...              batch_dims=1)
[3, 2]
           
>>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
[3, 2]
           

舉例來說,

params

indices

升高一個次元,即

batch_dims=2

,這時按照限制條件隻能

axis=2

params = tf.constant([  # 升高一個次元
    [[0, 0, 1, 0, 2],
     [3, 0, 0, 0, 4],
     [0, 5, 0, 6, 0]],

    [[1, 8, 4, 2, 2],
     [9, 6, 2, 3, 0],
     [7, 2, 8, 6, 3]]])
indices = tf.constant([  # 升高一個次元
    [[2, 4],
     [0, 4],
     [1, 3]],

    [[1, 3],
     [2, 1],
     [4, 2]]])
# 進行batch_dims高值gather計算
tf.gather(params, indices, axis=2, batch_dims=2).numpy()
# 則上面的運算等價于
def manually_batched_gather_3d(params, indices, axis):
  batch_dims=2
  result = []
  for p,i in zip(params, indices):  # 這裡面進行了batch_dims層(也就是2層)嵌套for循環
      result_2 = []
      for p_2, i_2 in zip(p,i):
          r = tf.gather(p_2, i_2, axis=axis-batch_dims)  # 這裡告訴我們為什麼axis必須>=batch_dims
          result_2.append(r)
      result.append(result_2)
  return tf.stack(result)
manually_batched_gather_3d(params, indices, axis=2).numpy()
# array([[[1, 2],
#         [3, 4],
#         [5, 6]],
#
#        [[8, 2],
#         [2, 6],
#         [3, 8]]], dtype=int32)
           

下面來解釋一下上面程式的運作過程,在上面的

manually_batched_gather_3d

運作過程中第一層

zip

的作用如下

pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))
# 列印得到如下list,該list有兩個元組組成,都是将兩個參數的axis=0軸上的兩個二維張量,分别進行了組包
# [([[0, 0, 1, 0, 2],
#    [3, 0, 0, 0, 4],
#    [0, 5, 0, 6, 0]],  # 到這兒為params的axis=0軸上的[0]二維張量
#   [[2, 4],
#    [0, 4],
#    [1, 3]]),  # 到這兒為indices的axis=0軸上的[0]二維張量
#
#  ([[1, 8, 4, 2, 2],
#    [9, 6, 2, 3, 0],
#    [7, 2, 8, 6, 3]],  # 到這兒為params的axis=0軸上的[1]二維張量
#   [[1, 3],
#    [2, 1],
#    [4, 2]])]  # 到這兒為indices的axis=0軸上的[1]二維張量
           

然後進入第一層for循環的第一次循環,将

zip

之後的兩個元組中的第一個元組,拿過來分别賦給

p

i

p=tf.Tensor(
[[0 0 1 0 2]
 [3 0 0 0 4]
 [0 5 0 6 0]], shape=(3, 5), dtype=int32)
i=tf.Tensor(
[[2 4]
 [0 4]
 [1 3]], shape=(3, 2), dtype=int32)
           

在第二層

for

之前插入,得到第二層的

zip

結果

print(list(zip(p.numpy().tolist(), i.numpy().tolist())))
# [([0, 0, 1, 0, 2], [2, 4]),
#  ([3, 0, 0, 0, 4], [0, 4]),
#  ([0, 5, 0, 6, 0], [1, 3])]
           

則開始第二層for的第一次循環,則

# p_2 = tf.Tensor([0 0 1 0 2], shape=(5,), dtype=int32)
# i_2 = tf.Tensor([2 4], shape=(2,), dtype=int32)
# r = tf.Tensor([1 2], shape=(2,), dtype=int32)
           

這之後第二層for循環再進行2次循環,退回到第一層大循環,第一層大循環再進行一次上述循環即完成了整個循環。

2.4 batch_dims再降為1

你會發現,下面兩條指令等價,即

batch_dims=1

隻有一層循環,隻

zip

一次

tf.gather(params, indices, axis=2, batch_dims=1).numpy()
# 等價于
manually_batched_gather(params, indices, axis=2).numpy()
# [[[[1 2]
#    [0 2]
#    [0 0]]
#
#   [[0 4]
#    [3 4]
#    [0 0]]
#
#   [[0 0]
#    [0 0]
#    [5 6]]]
#
#
#  [[[8 2]
#    [4 8]
#    [2 4]]
#
#   [[6 3]
#    [2 6]
#    [0 2]]
#
#   [[2 6]
#    [8 2]
#    [3 8]]]]
           

2.5 再将axis降為1

還需修改一下

indices

,因為下文有對

indices

的限制——必須在

[0, params.shape[axis]]

範圍内,此時

params.shape

(2, 3, 5)

,則

params.shape[1]=3

,是以

indices

隻能等于

1

2

,如果>=3索引的時候就會溢出。此時還是

batch_dims=1

隻有一層循環,隻

zip

一次,隻是改變了索引軸。

indices = tf.constant([
    [[1, 0],
     [2, 1],
     [2, 0]],

    [[2, 0],
     [0, 1],
     [1, 2]]])
tf.gather(params, indices, axis=1, batch_dims=1).numpy()
# 等價于
manually_batched_gather(params, indices, axis=1).numpy()
# array([[[[3, 0, 0, 0, 4],
#          [0, 0, 1, 0, 2]],
# 
#         [[0, 5, 0, 6, 0],
#          [3, 0, 0, 0, 4]],
# 
#         [[0, 5, 0, 6, 0],
#          [0, 0, 1, 0, 2]]],
# 
# 
#        [[[7, 2, 8, 6, 3],
#          [1, 8, 4, 2, 2]],
# 
#         [[1, 8, 4, 2, 2],
#          [9, 6, 2, 3, 0]],
# 
#         [[9, 6, 2, 3, 0],
#          [7, 2, 8, 6, 3]]]], dtype=int32)>>
           

2.6 batch_dims<0

因為

params

indices

一共由3各次元——

1

2

,其對應的負次元就是

-3

-2

-1

,是以下面兩條指令等價

a = tf.gather(params, indices, axis=2, batch_dims=1).numpy()
pprint(a)
# 等價于
a = tf.gather(params, indices, axis=2, batch_dims=-2).numpy()
pprint(a)
           

2.7 batch_dims總結

故個人認為,

batch_dims

是由batch和dimensions兩個單詞縮寫而成,因為dimensions為複數是以可以翻譯為“批量次元數”(自己翻譯沒有查到文獻),可以指批處理

batch_dims

個次元,如果是正數可以了解成嵌套幾層循環或者進行幾次

zip

,如果是負數需要轉化為對應的正次元再進行上述了解;也可以是指組包到哪一個次元上,如果是負數也同樣适用于這種解釋。

batch_dims

極大的擴充了

gather

的功能,使你可以将

params

indices

在對應的某個次元上分别進行

gather

然後再

stack

ps:關于

batch_dims

的這個解釋同樣也适用于tf.gather_nd。

3. 補充

如果您需要使用諸如 tf.argsort 或 tf.math.top_k 之類的操作的索引,其中索引的最後一個次元在相應位置索引到輸入的最後一個次元,這自然會出現。 在這種情況下,您可以使用 tf.gather(values, indices, batch_dims=-1)。

4. 參數和傳回值

參數

params

從中收集值的

Tensor

(張量)。其秩(rank)必須至少為

axis

+ 1。

indices

索引張量。 必須是以下類型之一:

int32

int64

。 這些值必須在

[0, params.shape[axis]]

範圍内。

validate_indices

已棄用,沒有任何作用。 索引總是在 CPU 上驗證,從不在 GPU 上驗證。

注意:在 CPU 上,如果發現越界索引,則會引發錯誤。 在 GPU 上,如果發現越界索引,則将 0 存儲在相應的輸出值中。

axis

一個

Tensor

((張量))。 必須是以下類型之一:

int32

int64

。 從參數

params

中的

axis

軸收集索引。 必須大于或等于

batch_dims

。 預設為第一個**非批次次元 **。 支援負索引。

batch_dims

一個

integer

(整數)。 批量次元(batch dimensions)的數量。 必須小于或等于

rank(indices)

name

操作的名稱(可選)。
傳回值
一個

Tensor

(張量), 與

params

具有相同的類型。

5. 其他相關論述

下面幾篇部落格,相對于官網手冊都有新的資訊增量,可以作為參考

  • 知網《tf.gather()函數》,使用索引推演的方式在次元和操作兩個方面進行了解,但是其關于

    batch_dims

    的描述不夠充分且有些片面;
  • 知乎《tf.gather()函數總結》,舉了一個新的例子,但是

    batch_dims

    還是隻到了1,沒有很好的歸納其真正的實體意義;
  • CSDN《tf.gather函數》,跟上一篇的情況差不多。

6. 附件

上文用到的調試程式,

可以忽略

import tensorflow as tf
from pprint import pprint

params = tf.constant([[0, 1.0, 2.0],
                      [10.0, 11.0, 12.0],
                      [20.0, 21.0, 22.0],
                      [30.0, 31.0, 32.0]])
a = tf.gather(params, indices=[[2,1], [1,0]], axis=1).numpy()
pprint(a)

params = tf.constant([
    [0, 0, 1, 0, 2],
    [3, 0, 0, 0, 4],
    [0, 5, 0, 6, 0]])
indices = tf.constant([
    [2, 4],
    [0, 4],
    [1, 3]])

a = tf.gather(params, indices, axis=1, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=1, batch_dims=-1).numpy()
pprint(a)

def manually_batched_gather(params, indices, axis):
  batch_dims=1
  result = []
  for p,i in zip(params, indices):
    r = tf.gather(p, i, axis=axis-batch_dims)
    result.append(r)
  return tf.stack(result)
manually_batched_gather(params, indices, axis=1).numpy()

pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))

tf.gather(params, indices, axis=1, batch_dims=0).numpy()
tf.gather(params, indices, axis=1).numpy()
# tf.gather(params, indices, axis=0, batch_dims=0).numpy()

params = tf.constant([[
    [0, 0, 1, 0, 2],
    [3, 0, 0, 0, 4],
    [0, 5, 0, 6, 0]]])
indices = tf.constant([[
    [2, 4],
    [0, 4],
    [1, 3]]])
pprint(list(zip(params.numpy().tolist(), indices.numpy().tolist())))

# [([[0, 0, 1, 0, 2], [3, 0, 0, 0, 4], [0, 5, 0, 6, 0]],
#   [[2, 4], [0, 4], [1, 3]])]

params_1 = [[0, 0, 1, 0, 2],
   [3, 0, 0, 0, 4],
   [0, 5, 0, 6, 0]],
indices_1 = [[2, 4],
   [0, 4],
   [1, 3]]

# a = tf.gather(params_1, indices_1, axis=0).numpy()

params = tf.constant([
    [[0, 0, 1, 0, 2],
     [3, 0, 0, 0, 4],
     [0, 5, 0, 6, 0]],

    [[1, 8, 4, 2, 2],
     [9, 6, 2, 3, 0],
     [7, 2, 8, 6, 3]]])
indices = tf.constant([
    [[2, 4],
     [0, 4],
     [1, 3]],

    [[1, 3],
     [2, 1],
     [4, 2]]])

a = tf.gather(params, indices, axis=2, batch_dims=2).numpy()
pprint(a)
a = tf.gather(params, indices, axis=2, batch_dims=-1).numpy()
pprint(a)

print(list(zip(params.numpy().tolist(), indices.numpy().tolist())))



# [([[0, 0, 1, 0, 2],
#    [3, 0, 0, 0, 4],
#    [0, 5, 0, 6, 0]],
#   [[2, 4],
#    [0, 4],
#    [1, 3]]),
#
#  ([[1, 8, 4, 2, 2],
#    [9, 6, 2, 3, 0],
#    [7, 2, 8, 6, 3]],
#   [[1, 3],
#    [2, 1],
#    [4, 2]])]

def manually_batched_gather_3(params, indices, axis):
  batch_dims=2
  result = []
  for p,i in zip(params, indices):
      result_2 = []
      print(list(zip(p.numpy().tolist(), i.numpy().tolist())))
      for p_2, i_2 in zip(p,i):
          r = tf.gather(p_2, i_2, axis=axis-batch_dims)
          result_2.append(r)
      result.append(result_2)
  return tf.stack(result)
manually_batched_gather_3(params, indices, axis=2).numpy()

# <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
# array([[[1, 2],
#         [3, 4],
#         [5, 6]],
#
#        [[8, 2],
#         [2, 6],
#         [3, 8]]], dtype=int32)>>

# [([0, 0, 1, 0, 2], [2, 4]),
#  ([3, 0, 0, 0, 4], [0, 4]),
#  ([0, 5, 0, 6, 0], [1, 3])]

a = tf.gather(params, indices, axis=2, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=2, batch_dims=-2).numpy()
pprint(a)
manually_batched_gather(params, indices, axis=2).numpy()
# [[[[1 2]
#    [0 2]
#    [0 0]]
#
#   [[0 4]
#    [3 4]
#    [0 0]]
#
#   [[0 0]
#    [0 0]
#    [5 6]]]
#
#
#  [[[8 2]
#    [4 8]
#    [2 4]]
#
#   [[6 3]
#    [2 6]
#    [0 2]]
#
#   [[2 6]
#    [8 2]
#    [3 8]]]]

indices = tf.constant([
    [[1, 0],
     [2, 1],
     [2, 0]],

    [[2, 0],
     [0, 1],
     [1, 2]]])

a = tf.gather(params, indices, axis=1, batch_dims=1).numpy()
pprint(a)
a = tf.gather(params, indices, axis=1, batch_dims=-2).numpy()
pprint(a)
manually_batched_gather(params, indices, axis=1).numpy()
# array([[[[3, 0, 0, 0, 4],
#          [0, 0, 1, 0, 2]],
#
#         [[0, 5, 0, 6, 0],
#          [3, 0, 0, 0, 4]],
#
#         [[0, 5, 0, 6, 0],
#          [0, 0, 1, 0, 2]]],
#
#
#        [[[7, 2, 8, 6, 3],
#          [1, 8, 4, 2, 2]],
#
#         [[1, 8, 4, 2, 2],
#          [9, 6, 2, 3, 0]],
#
#         [[9, 6, 2, 3, 0],
#          [7, 2, 8, 6, 3]]]], dtype=int32)>>