天天看點

pytorch中index_select()

**

pytorch中index_select()的用法

**

a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)
b = torch.index_select(a, 0, torch.tensor([0, 2]))
print(b)
print(a.index_select(0, torch.tensor([0, 2])))
c = torch.index_select(a, 1, torch.tensor([1, 3]))
print(c)

           

先定義了一個tensor,這裡用到了linspace和view方法。

第一個參數是索引的對象,第二個參數0表示按行索引,1表示按列進行索引,第三個參數是一個tensor,就是索引的序号,比如b裡面tensor[0, 2]表示第0行和第2行,c裡面tensor[1, 3]表示第1列和第3列。

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 2.,  4.],
        [ 6.,  8.],
        [10., 12.]])
           

繼續閱讀