天天看點

torch.cat 和 torch.stack

torch.cat 和 torch.stack看起來相似但是性質還是不同的

使用python中的list清單收錄tensor時,然後将list清單轉化成tensor時,會報錯。這個時候就要使用torch.stack進行堆疊,轉化成tensor。

  • torch.cat()

torch.cat(tensors,dim=0,out=None)→ Tensor

torch.cat()對tensors沿指定次元拼接,但傳回的Tensor的維數不會變

import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))
可以看到c和a、b一樣都是二維的。
           
  • torch.stack()

torch.stack(tensors,dim=0,out=None)→ Tensor

torch.stack()同樣是對tensors沿指定次元拼接,但傳回的Tensor會多一維

import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))
可以看到c是三維的,比a、b多了一維。
           

繼續閱讀