文章目錄
- Pytorch基本資料類型及操作(二)
-
- 1. 索引選取
- 2. 切片選取
- 3. 步長選取
- 4. 用...選取
- 5. 使用mask來索引
- 6. 使用 take 打成一維
Pytorch基本資料類型及操作(二)
1. 索引選取
a = torch.rand(4, 3, 28, 28) # 定義a是一個 4張28*28的RGB圖 的張量
# 單個選取
print(a.shape) # torch.Size([4, 3, 28, 28])
print(a[0].shape) # torch.Size([3, 28, 28])
print(a[0, 0].shape) # torch.Size([28, 28])
print(a[0, 0, 0].shape) # torch.Size([28])
print(a[0, 0, 2, 4].shape) # torch.Size([])
- a[0]:了解為取第 0 張圖檔,這張圖檔有 3 個通道,每個通道都是 28 * 28 的
- a[0, 0]:了解為取第 0 張圖檔的第 0 個通道,這個通道是 28 * 28 的
- a[0, 0, 0]:了解為取第 0 張圖檔的第 0 個通道的第 0 行像素點,這一行一共有 28 個像素點
- a[0, 0, 2, 4]:了解為取第 0 張圖檔的第 0 個通道的第 2 行第 4 列的像素點,它是一個點
- a[:2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,每個通道都是 28 * 28 的
2. 切片選取
a = torch.rand(4, 3, 28, 28) # 定義a是一個 4張28*28的RGB圖 的張量
# 連續選取
print(a[:2].shape) # torch.Size([2, 3, 28, 28])
print(a[:2, :1, :, :].shape) # torch.Size([2, 1, 28, 28])
print(a[:2, 1:, :, :].shape) # torch.Size([2, 2, 28, 28])
print(a[:2, -1:, :, :].shape) # torch.Size([2, 1, 28, 28])
- a[:2, :1]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 0 個通道,每個通道都是 28 * 28 的
- a[:2, 1:, :, :]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 1 個和第 2 個通道,每個通道都是 28 * 28 的
- a[:2, -1:, :, :]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔取第 2 個通道,每個通道都是 28 * 28 的
3. 步長選取
a = torch.rand(4, 3, 28, 28) # 定義a是一個 4張28*28的RGB圖 的張量
# 間隔選取
print(a[:, :, 0:28:2, 0:28:2].shape) # torch.Size([4, 3, 14, 14])
print(a[:, :, ::2, ::2].shape) # torch.Size([4, 3, 14, 14])
- a[:, :, 0:28:2, 0:28:2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,對行和列像素點以步長為2從0至28取行和列
- a[:, :, ::2, ::2]:了解為取第 1 張和第 2 張和第 3 張圖檔,每張圖檔有 3 個通道,對行和列像素點以步長為2從0至28取行和列
- dim:表示要操作的次元
- index:表示該次元下的哪些值,這裡的index必須是一個tensor,不能直接是一個list
a = torch.rand(4, 3, 28, 28) # 定義a是一個 4張28*28的RGB圖 的張量
print(a.index_select(0, torch.tensor([0, 2])).shape) # torch.Size([2, 3, 28, 28])
print(a.index_select(1, torch.tensor([1, 2])).shape) # torch.Size([4, 2, 28, 28])
print(a.index_select(2, torch.arange(28)).shape) # torch.Size([4, 3, 28, 28])
print(a.index_select(2, torch.arange(8)).shape) # torch.Size([4, 3, 8, 28])
- a.index_select(0, torch.tensor([0, 2])):表示要操作第 0 次元,即圖檔張數次元,取第 0 張和第 2 張圖檔
- a.index_select(1, torch.tensor([1, 2])):表示要操作第 1 次元,即圖檔通道次元,取第 1 個和第 2 個通道
- a.index_select(2, torch.arange(28)):表示要操作第 2 個次元,即圖檔高度像素點次元,取前 28 個行
- a.index_select(2, torch.arange(8)):表示要操作第 2 個次元,即圖檔高度像素點次元,取前 8 個行
4. 用…選取
-
可以取代...
,增加友善性:
print(a[...].shape) # torch.Size([4, 3, 28, 28])
print(a[0, ...].shape) # torch.Size([3, 28, 28])
print(a[:, 1, ...].shape) # torch.Size([4, 28, 28])
print(a[..., :2].shape) # torch.Size([4, 3, 28, 2])
- a[…]:表示選取所有次元(即 4 個次元)上的所有值,相當于a[:, :, :, :]
- a[0, …]:表示選取後3個次元的所有值
- a[:, 1, …]:表示選取第 1 和次元圖檔張數中所有值,選取第二個次元RGB三個通道中第1個通道,選取三四次元所有值
- a[…, :2]:表示選取前三個次元所有值,第四個次元寬度像素點次元中取前2個值
5. 使用mask來索引
x = torch.randn(3, 4)
print(x)
'''
tensor([[-0.4146, -0.1112, -0.6213, -0.3464],
[-1.0482, 0.2925, 1.0796, 0.1143],
[-0.7203, 0.5699, 1.3800, -0.3570]])
'''
mask = x.ge(0.5)
print(mask)
'''
tensor([[False, False, False, False],
[False, False, True, False],
[False, True, True, False]])
'''
c = torch.masked_select(x, mask)
print(c) # tensor([1.0796, 0.5699, 1.3800])
print(c.shape) # torch.Size([3])
print(c.dim()) # 1
- mask = x.ge(0.5):表示将 x 中所有大于等于 0.5 的設定為 True ,反之設定為 False ,生成由 True 和 False 組成的Tensor
- masked_select 函數是從 mask 中取出來所有值為 True 的值,形成一個 1 維的Tensor
6. 使用 take 打成一維
def take(input, index) -> Tensor
- input:是一個Tensor ,将input變成一維的Tensor
- index:是一個Tensor,表示從打平後的一維input中取出哪些下标的值
- 傳回值:由index為下标的值組成的一維Tensor
# src是一個二維的 2 * 3 的Tensor
src = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
dst = torch.take(src, torch.tensor([0, 2, 5]))
print(dst) # tensor([1, 3, 6])
print(dst.shape) # torch.Size([3])
print(dst.dim()) # 1
- torch.take(src, torch.tensor([0, 2, 5])) :表示将 src 打平成一維後,取出下标為 0、2、5的值組織成一個一維Tensor