torch.cat 和 torch.stack看起來相似但是性質還是不同的
使用python中的list清單收錄tensor時,然後将list清單轉化成tensor時,會報錯。這個時候就要使用torch.stack進行堆疊,轉化成tensor。
- torch.cat()
torch.cat(tensors,dim=0,out=None)→ Tensor
torch.cat()對tensors沿指定次元拼接,但傳回的Tensor的維數不會變
import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))
可以看到c和a、b一樣都是二維的。
- torch.stack()
torch.stack(tensors,dim=0,out=None)→ Tensor
torch.stack()同樣是對tensors沿指定次元拼接,但傳回的Tensor會多一維
import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))
可以看到c是三維的,比a、b多了一維。