
A detailed look at the partition_strategy parameter of tf.nn.embedding_lookup

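  • tf.nn.embedding_lookup
    • Mathematical background
    • API overview
    • Simple example
      • Program
      • Notes
    • Examples of the partition_strategy parameter
      • 'mod' example 1
      • 'mod' example 2
      • 'div' example 1
      • 'div' example 2
  • References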

tf.nn.embedding_lookup

embedding_lookup is commonly used in NLP to convert a one-hot encoding into its corresponding embedding vector.

Mathematical background

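Suppose there are m objects, each with its own unique id. There is then a trivial embedding from the set of objects into R^m that maps each object to a standard basis vector of R^m. This embedding is called one-hot embedding (or one-hot encoding).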

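In practice, objects are usually embedded into a lower-dimensional space R^n (n≪m); it suffices to compose the one-hot embedding with a linear map from R^m to R^n. Every n×m matrix M defines such a linear map: x ↦ Mx. When x is a standard basis vector, Mx is one column of M, and that column is the vector representation of the corresponding id. Drawn as a neural network, the idea looks like this: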

[Figure: neural-network diagram of the embedding lookup]

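Starting from an id (index), find the corresponding one-hot encoding; the red weights then directly give the values of the output nodes (note that there is no activation function here), and those values are the corresponding embedding vector.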
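
The claim that multiplying M by a standard basis vector picks out one column of M can be checked numerically. The sketch below uses NumPy rather than TensorFlow; the sizes m = 5 and n = 3, the matrix values and the variable names are illustrative assumptions, not taken from the original post.

```python
import numpy as np

m, n = 5, 3                          # assumed sizes: 5 objects, 3-dimensional embeddings
M = np.arange(n * m).reshape(n, m)   # arbitrary n x m matrix standing in for learned weights

obj_id = 2                           # hypothetical id to embed
x = np.zeros(m, dtype=M.dtype)
x[obj_id] = 1                        # one-hot encoding: the obj_id-th standard basis vector

via_matmul = M @ x                   # apply the linear map x -> Mx
via_lookup = M[:, obj_id]            # read column obj_id directly, no multiplication

print(via_matmul)
print(np.array_equal(via_matmul, via_lookup))
```

tf.nn.embedding_lookup stores the table with one row per id (the transpose of M above), so the same selection becomes a row lookup.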

API overview

Looks up the elements of embedding_params that correspond to inputs_ids.

embedding_lookup(
     params,   # embedding_params: the embedding table(s) to look up in
     ids,      # inputs_ids: the ids to look up
     partition_strategy='mod',   # partitioning scheme
     name=None,
     validate_indices=True, # deprecated
     max_norm=None
 )
           
Parameter / Description / Note

params
  Description: A single tensor representing the complete embedding tensor, or a list of P tensors all of same shape except for the first dimension, representing sharded embedding tensors. Alternatively, a PartitionedVariable, created by partitioning along dimension 0. Each element must be appropriately sized for the given partition_strategy.
  Note: params is either one tensor or a list of tensors; in a list, all tensors must agree in every dimension except the first.

ids
  Description: A Tensor with type int32 or int64 containing the ids to be looked up in params.
  Note: ids is an integer tensor; each element is the logical index, along dimension 0, of the row to fetch from params.

partition_strategy
  Description: A string specifying the partitioning strategy, relevant if len(params) > 1. Currently "div" and "mod" are supported. Default is "mod".
  Note: How a logical index maps to a row of a shard is determined by partition_strategy, which sets how ids are split. Two schemes are currently available: 'div' and 'mod'.

Return value
  Description: The results of the lookup are concatenated into a dense tensor. The returned tensor has shape shape(ids) + shape(params)[1:].
  Note: The return value is a dense tensor of shape shape(ids) + shape(params)[1:].
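
For a single, unsharded params tensor, the lookup behaves like integer-array indexing along dimension 0, which makes the shape rule easy to check. A minimal NumPy sketch follows; the table size and the ids are made up for illustration, and the code mirrors the documented shape rule rather than TensorFlow's implementation.

```python
import numpy as np

params = np.arange(8 * 3).reshape(8, 3)    # assumed table: 8 ids, 3-dimensional vectors
ids = np.array([[1, 4], [6, 3], [2, 5]])   # ids of shape (3, 2)

result = params[ids]                       # gather rows along dimension 0

print(result.shape)                        # shape(ids) + shape(params)[1:]
print(result[0, 1])                        # the row stored for id 4
```

Each id is replaced by its row, so the trailing dimensions of params are appended to the shape of ids.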

The partition_strategy parameter of embedding_lookup is fairly hard to understand (this function is hard to understand, until you get the point!); it is explained separately below.

Simple example

The following typical case shows how embedding_lookup is used:

Program

# coding:utf8
# Written for the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
_input_ids = tf.placeholder(dtype=tf.int32, shape=[3, 2])

embedding_param = tf.Variable(np.identity(8, dtype=np.int32))   # build an 8x8 identity matrix
input_embedding = tf.nn.embedding_lookup(embedding_param, input_ids)
_input_embedding = tf.nn.embedding_lookup(embedding_param, _input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())


print('embedding:')
print(embedding_param.eval())

var1 = [1, 2, 6, 4, 2, 5, 7]
print('\n var1:')
print(var1)

print('\nprojecting result:')
print(sess.run(input_embedding, feed_dict={input_ids: var1}))

var2 = [[1, 4], [6, 3], [2, 5]]
print('\n _var2:')
print(var2)

print('\n _projecting result:')
print(sess.run(_input_embedding, feed_dict={_input_ids: var2}))

'''

Output:
embedding:
[[1 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1]]

 var1:
[1, 2, 6, 4, 2, 5, 7]

projecting result:
[[0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 1]]

 _var2:
[[1, 4], [6, 3], [2, 5]]

 _projecting result:
[[[0 1 0 0 0 0 0 0]
  [0 0 0 0 1 0 0 0]]

 [[0 0 0 0 0 0 1 0]
  [0 0 0 1 0 0 0 0]]

 [[0 0 1 0 0 0 0 0]
  [0 0 0 0 0 1 0 0]]]

'''

           

Notes

  • embedding_param is an 8*8 identity matrix (here params consists of a single tensor, i.e. len(params)=1; partition_strategy only takes effect when len(params)>1).
embedding_param=          # embedding_param consists of a single tensor, so len(params) = 1
[[1 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1]]
           
  • With var1 as ids, the rows of embedding_param at those ids are returned.
var1 = [1, 2, 6, 4, 2, 5, 7]
    # 1 selects the 2nd row  --> [0 1 0 0 0 0 0 0]
    # 2 selects the 3rd row  --> [0 0 1 0 0 0 0 0]
    # etc.
           
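  • With var2 as ids, the rows of embedding_param at those ids are returned in the same way.
var2 = [[1, 4], [6, 3], [2, 5]]
    '''
    [1, 4] selects rows 2 and 5
    [[0 1 0 0 0 0 0 0]
     [0 0 0 0 1 0 0 0]]

    The remaining pairs work the same way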
    ''' 
           

Examples of the partition_strategy parameter

API description / Note

API description: If len(params) > 1, each element id of ids is partitioned between the elements of params according to the partition_strategy. In all strategies, if the id space does not evenly divide the number of partitions, each of the first (max_id + 1) % len(params) partitions will be assigned one more id.
Note: When len(params) > 1, ids are distributed over the elements of params according to partition_strategy. If the ids cannot be split evenly, each of the first (max_id + 1) % len(params) partitions receives one extra id.

API description: If partition_strategy is "mod", we assign each id to partition p = id % len(params). For instance, 13 ids are split across 5 partitions as: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
Note: For example, with partition_strategy='mod', if params consists of 5 tensors whose first dimensions sum to 13, the split is [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].

API description: If partition_strategy is "div", we assign ids to partitions in a contiguous manner. In this case, 13 ids are split across 5 partitions as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
Note: For example, with partition_strategy='div', if params consists of 5 tensors whose first dimensions sum to 13, the split is [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
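
The two splitting rules quoted above can be written out in a few lines of plain Python. This is a sketch of the documented rule only; split_ids is a hypothetical helper name, not a TensorFlow function, and it reproduces the 13-ids-over-5-partitions example.

```python
def split_ids(num_ids, num_partitions, strategy):
    """Group ids 0..num_ids-1 by partition, following the quoted rules."""
    partitions = [[] for _ in range(num_partitions)]
    if strategy == "mod":
        # id i goes to partition i % num_partitions
        for i in range(num_ids):
            partitions[i % num_partitions].append(i)
    elif strategy == "div":
        # contiguous blocks; the first `extras` partitions hold one more id
        base, extras = divmod(num_ids, num_partitions)
        start = 0
        for p in range(num_partitions):
            size = base + 1 if p < extras else base
            partitions[p] = list(range(start, start + size))
            start += size
    else:
        raise ValueError("unknown strategy: %r" % strategy)
    return partitions


mod_split = split_ids(13, 5, "mod")
div_split = split_ids(13, 5, "div")
print(mod_split)
print(div_split)
```

In both cases 13 % 5 = 3, so the first three partitions hold three ids and the last two hold two.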

If the API description still feels hazy, work through the four examples below and the behaviour of the function becomes clear.

'mod' example 1

# coding:utf8
# Written for the TensorFlow 1.x graph API (tf.Session).
import tensorflow as tf
import numpy as np


def test_embedding_lookup():
    # Three shards with 3, 1 and 3 rows; every row has 4 columns.
    a = np.arange(12).reshape(3, 4)
    b = np.arange(12, 16).reshape(1, 4)
    c = np.arange(16, 28).reshape(3, 4)
    print(a)
    print('\n')
    print(b)
    print('\n')
    print(c)
    print('\n')

    a = tf.Variable(a)
    b = tf.Variable(b)
    c = tf.Variable(c)

    t = tf.nn.embedding_lookup([a, b, c],
        partition_strategy='mod', ids=[0, 3, 6, 1, 2, 5, 8])

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    m = sess.run(t)
    print(m)


test_embedding_lookup()
           
'''
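        Analysis:
         Here params consists of the three tensors [a, b, c], i.e. len(params)=3, and the first dimensions of a, b and c are 3, 1 and 3 respectively.
         When these three tensors are combined, ids are split according to partition_strategy='mod': consecutive ids within one tensor differ by len(params). The split here is [a, b, c]  == [[0,3,6], [1,4,7], [2,5,8]]
         The program does not know in advance that ids 4 and 7 have no corresponding row; the error is raised only when those ids are actually looked up.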
        a=[[ 0  1  2  3]     = [0, 3, 6]  -->  [0  1  2  3]  = 0
           [ 4  5  6  7]                  -->  [4  5  6  7]  = 3
           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 6

        b=[[12 13 14 15]]    = [1, 4, 7]  -->  [12 13 14 15] = 1
                                          -->  runtime error  = 4
                                          -->  runtime error  = 7


        c = etc..      


        Output:
        [[ 0  1  2  3]
         [ 4  5  6  7]
         [ 8  9 10 11]]

        [[12 13 14 15]]

        [[16 17 18 19]
         [20 21 22 23]
         [24 25 26 27]]

        [[ 0  1  2  3]  # 0
         [ 4  5  6  7]  # 3
         [ 8  9 10 11]  # 6
         [12 13 14 15]  # 1
         [16 17 18 19]  # 2
         [20 21 22 23]  # 5
         [24 25 26 27]] # 8

        '''
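
The 'mod' analysis above can be replayed without TensorFlow. The sketch below is a NumPy emulation under the rule "id goes to shard id % P, row id // P"; mod_lookup is a hypothetical helper written for this illustration, and it raises IndexError where the analysis marks a runtime error.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
b = np.arange(12, 16).reshape(1, 4)
c = np.arange(16, 28).reshape(3, 4)


def mod_lookup(shards, ids):
    """Emulate a 'mod' lookup over a list of shards (illustrative only)."""
    num_shards = len(shards)
    rows = []
    for i in ids:
        shard = shards[i % num_shards]   # which shard holds this id
        row = i // num_shards            # position of the id inside that shard
        if row >= shard.shape[0]:
            raise IndexError("id %d needs row %d of a shard with %d row(s)"
                             % (i, row, shard.shape[0]))
        rows.append(shard[row])
    return np.stack(rows)


result = mod_lookup([a, b, c], [0, 3, 6, 1, 2, 5, 8])
print(result)

try:
    mod_lookup([a, b, c], [4])   # id 4 maps to row 1 of b, which has only one row
    failed = False
except IndexError as err:
    failed = True
    print(err)
```

The seven valid ids return the rows of a, then b, then c, which matches the output listed in the example.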
           

'mod' example 2

# coding:utf8
# Written for the TensorFlow 1.x graph API (tf.Session).
import tensorflow as tf
import numpy as np


def test_embedding_lookup():
    # Same three shards as before: 3, 1 and 3 rows of 4 columns.
    a = np.arange(12).reshape(3, 4)
    b = np.arange(12, 16).reshape(1, 4)
    c = np.arange(16, 28).reshape(3, 4)
    print(a)
    print('\n')
    print(b)
    print('\n')
    print(c)
    print('\n')

    a = tf.Variable(a)
    b = tf.Variable(b)
    c = tf.Variable(c)

    # The shard order is now [a, c, b].
    t = tf.nn.embedding_lookup([a, c, b],
        partition_strategy='mod', ids=[0, 3, 6, 1, 4, 7, 2])

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    m = sess.run(t)
    print(m)


test_embedding_lookup()
           
'''
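        Analysis:
         Here params is changed from [a, b, c] to [a, c, b]. The first dimensions of a, c and b are 3, 3 and 1 respectively.
         When these three tensors are combined, consecutive ids within one tensor still differ by len(params). The split here is [a, c, b]  == [[0,3,6], [1,4,7], [2,5,8]]
         The program does not know in advance that ids 5 and 8 have no corresponding row; the error is raised only when those ids are actually looked up.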
        a=[[ 0  1  2  3]     = [0, 3, 6]  -->  [0  1  2  3]  = 0
           [ 4  5  6  7]                  -->  [4  5  6  7]  = 3
           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 6

        c=[[16 17 18 19]     = [1, 4, 7]  -->  [16 17 18 19]  = 1
           [20 21 22 23]                  -->  [20 21 22 23]  = 4
           [24 25 26 27]]                -->  [24 25 26 27]  = 7


        b=[[12 13 14 15]]    = [2, 5, 8]  -->  [12 13 14 15] = 2
                                          -->  runtime error  = 5
                                          -->  runtime error  = 8


        Output:
        [[ 0  1  2  3]
         [ 4  5  6  7]
         [ 8  9 10 11]]

        [[12 13 14 15]]

        [[16 17 18 19]
         [20 21 22 23]
         [24 25 26 27]]

        [[ 0  1  2  3]  # 0
         [ 4  5  6  7]  # 3
         [ 8  9 10 11]  # 6
         [16 17 18 19]  # 1
         [20 21 22 23]  # 4
         [24 25 26 27]  # 7
         [12 13 14 15]] # 2

        '''
           

'div' example 1

# coding:utf8
# Written for the TensorFlow 1.x graph API (tf.Session).
import tensorflow as tf
import numpy as np


def test_embedding_lookup():
    # Three shards with 3, 1 and 3 rows; every row has 4 columns.
    a = np.arange(12).reshape(3, 4)
    b = np.arange(12, 16).reshape(1, 4)
    c = np.arange(16, 28).reshape(3, 4)
    print(a)
    print('\n')
    print(b)
    print('\n')
    print(c)
    print('\n')

    a = tf.Variable(a)
    b = tf.Variable(b)
    c = tf.Variable(c)

    t = tf.nn.embedding_lookup([a, b, c],
        partition_strategy='div', ids=[0, 1, 2, 3, 5, 6])

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    m = sess.run(t)
    print(m)


test_embedding_lookup()
           
'''
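        Analysis:
         Here params is again [a, b, c]; the first dimensions of the three tensors are 3, 1 and 3.

         When these three tensors are combined, ids are split according to partition_strategy='div': consecutive ids within one tensor differ by 1. If the ids cannot be split evenly, the first (max_id+1)%len(params) tensors each get one extra id. Here there are 7 ids in 3 groups, so they are allotted 3, 2, 2.

         The split here is [a, b, c]  == [[0,1,2], [3,4], [5,6]]

         The program does not know in advance that id 4 has no corresponding row; the error is raised only when that id is actually looked up.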
        a=[[ 0  1  2  3]     = [0, 1, 2]  -->  [0  1  2  3]  = 0
           [ 4  5  6  7]                  -->  [4  5  6  7]  = 1
           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 2




        b=[[12 13 14 15]]    = [3, 4]  -->  [12 13 14 15] = 3
                                          -->  runtime error  = 4

        c=[[16 17 18 19]     = [5, 6]  -->  [16 17 18 19]  = 5
           [20 21 22 23]                  -->  [20 21 22 23]  = 6
           [24 25 26 27]]                -->  [24 25 26 27]  = unreachable (no id maps to this row)


        Output:
        [[ 0  1  2  3]
         [ 4  5  6  7]
         [ 8  9 10 11]]

        [[12 13 14 15]]

        [[16 17 18 19]
         [20 21 22 23]
         [24 25 26 27]]

        [[ 0  1  2  3]  # 0
         [ 4  5  6  7]  # 1
         [ 8  9 10 11]  # 2
         [12 13 14 15]  # 3
         [16 17 18 19]  # 5
         [20 21 22 23]] # 6

        '''
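
The 'div' analysis can be replayed the same way. The sketch below is a NumPy emulation that assumes the id count equals the total number of rows across the shards (7 here), which is consistent with the 3, 2, 2 allotment described above; div_lookup is a hypothetical helper and not TensorFlow's implementation.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
b = np.arange(12, 16).reshape(1, 4)
c = np.arange(16, 28).reshape(3, 4)


def div_lookup(shards, ids):
    """Emulate a 'div' lookup; assumes every shard has at least one row."""
    num_ids = sum(s.shape[0] for s in shards)       # assumed size of the id space
    base, extras = divmod(num_ids, len(shards))
    threshold = extras * (base + 1)                 # ids below this sit in the larger partitions
    rows = []
    for i in ids:
        if i < threshold:
            p, row = divmod(i, base + 1)
        else:
            p, row = divmod(i - threshold, base)
            p += extras
        if p >= len(shards) or row >= shards[p].shape[0]:
            raise IndexError("id %d has no matching row" % i)
        rows.append(shards[p][row])
    return np.stack(rows)


result = div_lookup([a, b, c], [0, 1, 2, 3, 5, 6])
print(result)

try:
    div_lookup([a, b, c], [4])   # id 4 maps to row 1 of b, which has only one row
    failed = False
except IndexError as err:
    failed = True
    print(err)
```

Because the allotment depends only on the total row count and the number of shards, the last row of c is never reached: no id is assigned to it.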
           

'div' example 2

# coding:utf8
# Written for the TensorFlow 1.x graph API (tf.Session).
import tensorflow as tf
import numpy as np


def test_embedding_lookup():
    # Same three shards as before: 3, 1 and 3 rows of 4 columns.
    a = np.arange(12).reshape(3, 4)
    b = np.arange(12, 16).reshape(1, 4)
    c = np.arange(16, 28).reshape(3, 4)
    print(a)
    print('\n')
    print(b)
    print('\n')
    print(c)
    print('\n')

    a = tf.Variable(a)
    b = tf.Variable(b)
    c = tf.Variable(c)

    # The shard order is now [a, c, b].
    t = tf.nn.embedding_lookup([a, c, b],
        partition_strategy='div', ids=[0, 1, 2, 3, 4, 5])

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    m = sess.run(t)
    print(m)


test_embedding_lookup()
           
'''
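        Analysis:
         Here params is changed to [a, c, b]; the first dimensions of the three tensors are 3, 3 and 1.

         When these three tensors are combined, ids are split according to partition_strategy='div'. Here there are 7 ids in 3 groups, so they are allotted 3, 2, 2.

         The split here is [a, c, b]  == [[0,1,2], [3,4], [5,6]]

         The program does not know in advance that id 6 has no corresponding row; the error is raised only when that id is actually looked up.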
        a=[[ 0  1  2  3]     = [0, 1, 2]  -->  [0  1  2  3]  = 0
           [ 4  5  6  7]                  -->  [4  5  6  7]  = 1
           [ 8  9 10 11]]                 -->  [8  9 10 11]  = 2

        c=[[16 17 18 19]     = [3, 4]  -->  [16 17 18 19]  = 3
           [20 21 22 23]                  -->  [20 21 22 23]  = 4
           [24 25 26 27]]                -->  [24 25 26 27]  = unreachable (no id maps to this row)

        b=[[12 13 14 15]]    = [5, 6]  -->  [12 13 14 15] = 5
                                          -->  runtime error  = 6


        Output:
        [[ 0  1  2  3]
         [ 4  5  6  7]
         [ 8  9 10 11]]

        [[12 13 14 15]]

        [[16 17 18 19]
         [20 21 22 23]
         [24 25 26 27]]

        [[ 0  1  2  3]  # 0
         [ 4  5  6  7]  # 1
         [ 8  9 10 11]  # 2
         [16 17 18 19]  # 3
         [20 21 22 23]  # 4
         [12 13 14 15]] # 5

        '''
           

References

https://stackoverflow.com/questions/34870614/what-does-tf-nn-embedding-lookup-function-do/41922877#41922877?newreg=5119f86ea49b43aa8988a833294ceb3e

https://www.zhihu.com/question/52250059
