天天看點

python擷取數組中大于某一門檻值的那些索引值_numpy,pytorch等架構的進階索引

python擷取數組中大于某一門檻值的那些索引值_numpy,pytorch等架構的進階索引

進階索引

在numpy或pytorch等架構中的對張量的操作不隻提供了類似python清單的切片索引等操作,還提供了進階索引。

進階索引大緻可分為整數數組索引、布爾索引及花式索引三類:

整數數組索引

整數數組索引指使用同形狀的多個數組分别指定元素的所有次元(不指定的次元也可以用切片

:

或省略号

...

與索引數組組合),可精準取出一批指定位置的元素,按照給定數組形狀傳回。

如給定4*3 的二維數組

x = np.arange(12).reshape(4,3)
print(x)
           

輸出為:

[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
           

1.取出其(0,0),(1,1)和(2,0)位置處的元素。

給出兩個一維數組分别指定第0維和第1維

y = x[[0,1,2],  [0,1,0]]
print(y)
#[0 4 6]
           

2.傳入的位置指定數組也可為多元,取出其[[(0,0),(0,2)],[3,0, 3,2]]的元素

rows = np.array([[0,0],[3,3]]) 
cols = np.array([[0,2],[0,2]])
y = x[rows,cols]  
print (y)
           

輸出為

[[ 0  2]
 [ 9 11]]
           

舉例應用場景:

改變矩陣對角線元素,進階索引除了用來得到元素,也可以像平常索引一樣對原張量進行操作。

給定一個33 3 的矩陣

x = np.arange(27).reshape(3,3,3)
print(x)
           

輸出:

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
           

将其標明的對角線元素改變

items = range(3)
x[items,1:, items] = -2
print(x)
           

輸出

array([[[ 0,  1,  2],
        [-2,  4,  5],
        [-2,  7,  8]],

       [[ 9, 10, 11],
        [12, -2, 14],
        [15, -2, 17]],

       [[18, 19, 20],
        [21, 22, -2],
        [24, 25, -2]]])
           

2.布爾索引

我們可以通過一個布爾數組來索引目标數組。

布爾數組性狀為目标數組的次元

給定一個33 3 的矩陣

x = np.arange(27).reshape(3,3,3)
print(x)
           

輸出:

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
           

使用布爾數組索引位置:

idx1 = np.array([0,1,1], dtype=bool)
idx2 = np.array([[[ 0,  0,  0],
                    [ 0,  0,  0],
                    [ 0,  0,  0]],

                   [[ 0,  1,  1],
                    [ 0,  0,  0],
                    [ 0,  0,  0]],

                   [[ 0,  0,  0],
                    [ 0,  0,  0],
                    [ 0,  0,  0]]], dtype=bool)

print(x[idx1])
print(x[idx2])
           

輸出為

[[[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]#x[idx1]

 [10 11] #x[idx2]
           

布爾索引也通過布爾運算(如:比較運算符)來擷取符合指定條件的元素的數組。

以下執行個體擷取大于 5 的元素:

x = np.array([[  0,  1,  2],[  3,  4,  5],[  6,  7,  8],[  9,  10,  11]])  
print ('我們的數組是:')
print (x)
print ('n')
# 現在我們會列印出大于 5 的元素  
print  ('大于 5 的元素是:')
print (x[x >  5])
           

輸出

我們的數組是:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]


大于 5 的元素是:
[ 6  7  8  9 10 11]
           

以下執行個體使用了

~

(取補運算符)來過濾 NaN。

a = np.array([np.nan,  1,2,np.nan,3,4,5])  
print (a[~np.isnan(a)])
           

輸出

[ 1.   2.   3.   4.   5.]
           

以下執行個體示範如何從數組中過濾掉非複數元素。

a = np.array([1,  2+6j,  5,  3.5+5j])  
print (a[np.iscomplex(a)])
           

輸出

[2.0+6.j  3.5+5.j]
           

花式索引

花式索引與整數數組索引不同,一次使用一個索引數組指定一個次元來指定某個軸的下标,其他次元可以用切片

:

或省略号

...

與正常索引。

花式索引跟切片不一樣,它總是将資料複制到新數組中,是原張量的副本。

1、傳入順序索引數組

x=np.arange(32).reshape((8,4))
print (x[[4,2,1,7]])
           

輸出

[[16 17 18 19]
 [ 8  9 10 11]
 [ 4  5  6  7]
 [28 29 30 31]]
           

2、傳入倒序索引數組

x=np.arange(32).reshape((8,4))
print (x[[-4,-2,-1,-7]])
           

輸出結果為:

[[16 17 18 19]
 [24 25 26 27]
 [28 29 30 31]
 [ 4  5  6  7]]
           

3、也可指定多元,傳入多個索引數組(要使用np.ix_)

np.ix_分别指定其兩個次元1,5,7,2行,和0,3,1,2列

x=np.arange(32).reshape((8,4))
print (x[np.ix_([1,5,7,2],[0,3,1,2])])
           

輸出結果為:

[[ 4  7  5  6]
 [20 23 21 22]
 [28 31 29 30]
 [ 8 11  9 10]]
           

舉例,類似multi-head,假設一句話的隐藏矩陣為(50,200), 50為句子長,200為隐藏次元,建構兩個詞兩兩組合的矩陣。meshgrid生成兩個(50,50)的矩陣,元素值分别為行号和列号的廣播形式

import torch
h = torch.randn(50,200)
r_idx, c_idx = torch.meshgrid(torch.arange(50), torch.arange(50))
muilti_h = torch.cat(h[r_idx], h[c_idx],dim=-1)# (50,50,200)