天天看點

【Pytorch】pytorch基本資料類型及操作(二)Pytorch基本資料類型及操作(二)

文章目錄

  • Pytorch基本資料類型及操作(二)
    • 1. 索引選取
    • 2. 切片選取
    • 3. 步長選取
    • 4. 用...選取
    • 5. 使用mask來索引
    • 6. 使用 take 打成一維

Pytorch基本資料類型及操作(二)

1. 索引選取

a = torch.rand(4, 3, 28, 28)     # 定義a是一個 4張28*28的RGB圖 的張量

# 單個選取
print(a.shape)          # torch.Size([4, 3, 28, 28])
print(a[0].shape)       # torch.Size([3, 28, 28])
print(a[0, 0].shape)    # torch.Size([28, 28])
print(a[0, 0, 0].shape) # torch.Size([28])
print(a[0, 0, 2, 4].shape)  # torch.Size([])
           
  • a[0]:了解為取第 0 張圖檔,這張圖檔有 3 個通道,每個通道都是 28 * 28 的
  • a[0, 0]:了解為取第 0 張圖檔的第 0 個通道,這個通道是 28 * 28 的
  • a[0, 0, 0]:了解為取第 0 張圖檔的第 0 個通道的第 0 行像素點,這一行一共有 28 個像素點
  • a[0, 0, 2, 4]:了解為取第 0 張圖檔的第 0 個通道的第 2 行第 4 列的像素點,它是一個點
  • a[:2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,每個通道都是 28 * 28 的

2. 切片選取

a = torch.rand(4, 3, 28, 28)     # 定義a是一個 4張28*28的RGB圖 的張量

# 連續選取
print(a[:2].shape)              # torch.Size([2, 3, 28, 28])
print(a[:2, :1, :, :].shape)    # torch.Size([2, 1, 28, 28])
print(a[:2, 1:, :, :].shape)    # torch.Size([2, 2, 28, 28])
print(a[:2, -1:, :, :].shape)   # torch.Size([2, 1, 28, 28])
           
  • a[:2, :1]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 0 個通道,每個通道都是 28 * 28 的
  • a[:2, 1:, :, :]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 1 個和第 2 個通道,每個通道都是 28 * 28 的
  • a[:2, -1:, :, :]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 2 個通道,每個通道都是 28 * 28 的

3. 步長選取

a = torch.rand(4, 3, 28, 28)     # 定義a是一個 4張28*28的RGB圖 的張量

# 間隔選取
print(a[:, :, 0:28:2, 0:28:2].shape)    # torch.Size([4, 3, 14, 14])
print(a[:, :, ::2, ::2].shape)          # torch.Size([4, 3, 14, 14])
           
  • a[:, :, 0:28:2, 0:28:2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,對行和列像素點以步長為2從0至28取行和列
  • a[:, :, ::2, ::2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,對行和列像素點以步長為2從0至28取行和列
  • dim:表示要操作的次元
  • index:表示該次元下的哪些值,這裡的index必須是一個tensor,不能直接是一個list
a = torch.rand(4, 3, 28, 28)     # 定義a是一個 4張28*28的RGB圖 的張量

print(a.index_select(0, torch.tensor([0, 2])).shape)    # torch.Size([2, 3, 28, 28])
print(a.index_select(1, torch.tensor([1, 2])).shape)    # torch.Size([4, 2, 28, 28])
print(a.index_select(2, torch.arange(28)).shape)        # torch.Size([4, 3, 28, 28])
print(a.index_select(2, torch.arange(8)).shape)         # torch.Size([4, 3, 8, 28])
           
  • a.index_select(0, torch.tensor([0, 2])):表示要操作第 0 次元,即圖檔張數次元,取第 0 張和第 2 張圖檔
  • a.index_select(1, torch.tensor([1, 2])):表示要操作第 1 次元,即圖檔通道次元,取第 1 個和第 2 個通道
  • a.index_select(2, torch.arange(28)):表示要操作第 2 個次元,即圖檔高度像素點次元,取前 28 個行
  • a.index_select(2, torch.arange(8)):表示要操作第 2 個次元,即圖檔高度像素點次元,取前 8 個行

4. 用…選取

  • ...

    可以取代

    :

    ,增加友善性
print(a[...].shape) 		# torch.Size([4, 3, 28, 28])
print(a[0, ...].shape) 		# torch.Size([3, 28, 28])
print(a[:, 1, ...].shape) 	# torch.Size([4, 28, 28])
print(a[..., :2].shape) 	# torch.Size([4, 3, 28, 2])
           
  • a[…]:表示選取所有次元(即 4 個次元)上的所有值,相當于a[:, :, :, :]
  • a[0, …]:表示選取後3個次元的所有值
  • a[:, 1, …]:表示選取第 1 和次元圖檔張數中所有值,選取第二個次元RGB三個通道中第1個通道,選取三四次元所有值
  • a[…, :2]:表示選取前三個次元所有值,第四個次元寬度像素點次元中取前2個值

5. 使用mask來索引

x = torch.randn(3, 4)
print(x)
'''
tensor([[-0.4146, -0.1112, -0.6213, -0.3464],
        [-1.0482,  0.2925,  1.0796,  0.1143],
        [-0.7203,  0.5699,  1.3800, -0.3570]])
'''

mask = x.ge(0.5)	
print(mask)
'''
tensor([[False, False, False, False],
        [False, False,  True, False],
        [False,  True,  True, False]])
'''

c = torch.masked_select(x, mask)
print(c)         # tensor([1.0796, 0.5699, 1.3800])
print(c.shape)   # torch.Size([3])
print(c.dim())   # 1
           
  • mask = x.ge(0.5):表示将 x 中所有大于等于 0.5 的設定為 True ,反之設定為 False ,生成由 True 和 False 組成的Tensor
  • masked_select 函數是從 mask 中取出來所有值為 True 的值,形成一個 1 維的Tensor

6. 使用 take 打成一維

def take(input, index) -> Tensor
           
  • input:是一個Tensor ,将input變成一維的Tensor
  • index:是一個Tensor,表示從打平後的一維input中取出哪些下标的值
  • 傳回值:由index為下标的值組成的一維Tensor
# src是一個二維的 2 * 3 的Tensor
src = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

dst = torch.take(src, torch.tensor([0, 2, 5]))  
print(dst)          # tensor([1, 3, 6])
print(dst.shape)    # torch.Size([3])
print(dst.dim())    # 1
           
  • torch.take(src, torch.tensor([0, 2, 5])) :表示将 src 打平成一維後,取出下标為 0、2、5的值組織成一個一維Tensor

繼續閱讀