天天看点

torch扩展维度

unsqueeze需要,否则报错

expand(3,2,2)参数就是目标维度。

是复制方式,最后的维度必须能整除

if __name__ == '__main__':

    import torch

    x = torch.Tensor([[1,2], [2,3], [3,4]])
    print(x.size())

    print(x)
    d=x.unsqueeze(1).expand(3,2, 2)
    print(d)