天天看点

torch.stack()函数的使用理解

torch.stack(sequence, dim)

sqequence– 待连接的张量序列

dim (int) – 插入的维度。

torch.stack()函数和torch.cat()有所不同,torch.stack()并不在已有的维度进行拼接,而是沿着新的维度进行拼接。

我在使用torch.stack()产生了两个问题:

1.怎么确定新的维度产生在哪里?

2.指定了新维度后要怎么拼接?

下面我以两个张量来说明,分别是A和B

A = torch.arange(6.0).reshape(2,3)
B = torch.linspace(0,10,6).reshape(2,3)
           

A和B是这样的2维张量

A= tensor([[0., 1., 2.],
        [3., 4., 5.]])       
B= tensor([[ 0.,  2.,  4.],
        [ 6.,  8., 10.]])       
           

下面说说我对torch.stack()函数的使用理解:

既然知道函数会为张量产生一个新维度,那么我们可以假设,令A和B维度升级,从(2,3)变为(1,2,3),即:

A1= tensor([[[0., 1., 2.],
        [3., 4., 5.]]])       
B1= tensor([[[ 0.,  2.,  4.],
        [ 6.,  8., 10.]]])       
#A1、B1比A、B在最外层多了一组括号[]
           

这样,接下里就很好解释了。

参数dim表示相连维度在这3维里的索引,如用link表示连接维度:

dim=0时,(link,#,#)

dim=1时,(#,link,#)

dim=2时,(#,#,link)

link所在的维度是哪个,就把A1和B1对应维度里的元素逐个相连。

下面我对每个维度都演示一遍

dim=0

F1 = torch.stack((A,B),dim=0)
print('F1=',F1)
print('F1的维度是',F1.shape)
           

运行结果:

F1= tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],
        [[ 0.,  2.,  4.],
         [ 6.,  8., 10.]]])
F1的维度是 torch.Size([2, 2, 3])
           

用上面的说法来理解,这相当于在A1和B1的第0维度里,每个元素依次相连,每对连接元素用[ ]包装。

A1的第0维度里只有A一个元素

B1的第0维度里只有B一个元素

因此如上所示,F1的结果其实就是[A,B]

dim=1

F2 = torch.stack((A,B),dim=1)
print('F2=',F2)
print('F2的维度是',F2.shape)
           

运行结果:

F2= tensor([[[ 0.,  1.,  2.],
         [ 0.,  2.,  4.]],
        [[ 3.,  4.,  5.],
         [ 6.,  8., 10.]]])
F2的维度是 torch.Size([2, 2, 3])
           

沿用上面的理解

A1的第1维度里的两个元素:[0. 1, 2],[3, 4, 5]

B1的第1维度里的两个元素:[0, 2, 4],[6, 8, 10]

[0. 1, 2]和[0, 2, 4]相连,[ ]包起来

[3, 4, 5]和[6, 8, 10]相连,[ ]包起来

最后给以上两组用[ ]包起来

dim=2

F3 = torch.stack((A,B),dim=2)
print('F3=',F3)
print('F3的维度是',F3.shape)
           

运行结果

F3= tensor([[[ 0.,  0.],
         [ 1.,  2.],
         [ 2.,  4.]],
        [[ 3.,  6.],
         [ 4.,  8.],
         [ 5., 10.]]])
F3的维度是 torch.Size([2, 3, 2])
           

A1和B1的第2维度里每个元素依次相连,每对连接元素用[]包装

A1第2维度里的元素:0,1,2,3,4,5

B1第2维度里的元素:0,2,4,6,8,10

两两相连后打包[0,0] [1,2] [2,4] [3,6] [4,8] [5,10]

由于[0,1,2]和[0,2,4]的第1维度属性是0

[3,4,5]和[6,8,10]的第1维度属性是1

所以把第一维度属性是0的和是1的单独打包

即[[0,0],[1,2],[2,4]]和[[3,6],[4,8],[5,10]]

最后将以上两组一起[]包起来

继续阅读