進階索引
在numpy或pytorch等架構中的對張量的操作不隻提供了類似python清單的切片索引等操作,還提供了進階索引。
進階索引大緻可分為整數數組索引、布爾索引及花式索引三類:
整數數組索引
整數數組索引指使用同形狀的多個數組分别指定元素的所有次元(不指定的次元也可以用切片
:
或省略号
...
與索引數組組合),可精準取出一批指定位置的元素,按照給定數組形狀傳回。
如給定4*3 的二維數組
x = np.arange(12).reshape(4,3)
print(x)
輸出為:
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
1.取出其(0,0),(1,1)和(2,0)位置處的元素。
給出兩個一維數組分别指定第0維和第1維
y = x[[0,1,2], [0,1,0]]
print(y)
#[0 4 6]
2.傳入的位置指定數組也可為多元,取出其[[(0,0),(0,2)],[3,0, 3,2]]的元素
rows = np.array([[0,0],[3,3]])
cols = np.array([[0,2],[0,2]])
y = x[rows,cols]
print (y)
輸出為
[[ 0 2]
[ 9 11]]
舉例應用場景:
改變矩陣對角線元素,進階索引除了用來得到元素,也可以像平常索引一樣對原張量進行操作。
給定一個33 3 的矩陣
x = np.arange(27).reshape(3,3,3)
print(x)
輸出:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
将其標明的對角線元素改變
items = range(3)
x[items,1:, items] = -2
print(x)
輸出
array([[[ 0, 1, 2],
[-2, 4, 5],
[-2, 7, 8]],
[[ 9, 10, 11],
[12, -2, 14],
[15, -2, 17]],
[[18, 19, 20],
[21, 22, -2],
[24, 25, -2]]])
2.布爾索引
我們可以通過一個布爾數組來索引目标數組。
布爾數組性狀為目标數組的次元
給定一個33 3 的矩陣
x = np.arange(27).reshape(3,3,3)
print(x)
輸出:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
使用布爾數組索引位置:
idx1 = np.array([0,1,1], dtype=bool)
idx2 = np.array([[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 1, 1],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]]], dtype=bool)
print(x[idx1])
print(x[idx2])
輸出為
[[[ 9 10 11]
[12 13 14]
[15 16 17]]
[[18 19 20]
[21 22 23]
[24 25 26]]]#x[idx1]
[10 11] #x[idx2]
布爾索引也通過布爾運算(如:比較運算符)來擷取符合指定條件的元素的數組。
以下執行個體擷取大于 5 的元素:
x = np.array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
print ('我們的數組是:')
print (x)
print ('n')
# 現在我們會列印出大于 5 的元素
print ('大于 5 的元素是:')
print (x[x > 5])
輸出
我們的數組是:
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
大于 5 的元素是:
[ 6 7 8 9 10 11]
以下執行個體使用了
~(取補運算符)來過濾 NaN。
a = np.array([np.nan, 1,2,np.nan,3,4,5])
print (a[~np.isnan(a)])
輸出
[ 1. 2. 3. 4. 5.]
以下執行個體示範如何從數組中過濾掉非複數元素。
a = np.array([1, 2+6j, 5, 3.5+5j])
print (a[np.iscomplex(a)])
輸出
[2.0+6.j 3.5+5.j]
花式索引
花式索引與整數數組索引不同,一次使用一個索引數組指定一個次元來指定某個軸的下标,其他次元可以用切片
:
或省略号
...
與正常索引。
花式索引跟切片不一樣,它總是将資料複制到新數組中,是原張量的副本。
1、傳入順序索引數組
x=np.arange(32).reshape((8,4))
print (x[[4,2,1,7]])
輸出
[[16 17 18 19]
[ 8 9 10 11]
[ 4 5 6 7]
[28 29 30 31]]
2、傳入倒序索引數組
x=np.arange(32).reshape((8,4))
print (x[[-4,-2,-1,-7]])
輸出結果為:
[[16 17 18 19]
[24 25 26 27]
[28 29 30 31]
[ 4 5 6 7]]
3、也可指定多元,傳入多個索引數組(要使用np.ix_)
np.ix_分别指定其兩個次元1,5,7,2行,和0,3,1,2列
x=np.arange(32).reshape((8,4))
print (x[np.ix_([1,5,7,2],[0,3,1,2])])
輸出結果為:
[[ 4 7 5 6]
[20 23 21 22]
[28 31 29 30]
[ 8 11 9 10]]
舉例,類似multi-head,假設一句話的隐藏矩陣為(50,200), 50為句子長,200為隐藏次元,建構兩個詞兩兩組合的矩陣。meshgrid生成兩個(50,50)的矩陣,元素值分别為行号和列号的廣播形式
import torch
h = torch.randn(50,200)
r_idx, c_idx = torch.meshgrid(torch.arange(50), torch.arange(50))
muilti_h = torch.cat(h[r_idx], h[c_idx],dim=-1)# (50,50,200)