天天看點

pytorch nn.Linear(x)中x的資料次元

當x為:

torch.manual_seed(10)  # 随機數種子
x = torch.linspace(1, 10, 50)  # 生成等間距張量
y = 2 * x + 3 * torch.rand(50)
print(x)
           

輸出:

tensor([ 1.0000,  1.1837,  1.3673,  1.5510,  1.7347,  1.9184,  2.1020,  2.2857,
         2.4694,  2.6531,  2.8367,  3.0204,  3.2041,  3.3878,  3.5714,  3.7551,
         3.9388,  4.1224,  4.3061,  4.4898,  4.6735,  4.8571,  5.0408,  5.2245,
         5.4082,  5.5918,  5.7755,  5.9592,  6.1429,  6.3265,  6.5102,  6.6939,
         6.8776,  7.0612,  7.2449,  7.4286,  7.6122,  7.7959,  7.9796,  8.1633,
         8.3469,  8.5306,  8.7143,  8.8980,  9.0816,  9.2653,  9.4490,  9.6327,
         9.8163, 10.0000])
           

時,為(1,50)的資料在代入nn.Linear(x)後:

---> 10         x = self.linear(x)
     11         return x
     12 ### 代碼結束 ###

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1370         ret = torch.addmm(bias, input, weight.t())
   1371     else:
-> 1372         output = input.matmul(weight.t())
   1373         if bias is not None:
   1374             output += bias

RuntimeError: size mismatch, m1: [1 x 50], m2: [1 x 1] at /opt/conda/conda-bld/pytorch_1579022060824/work/aten/src/TH/generic/THTensorMath.cpp:136
           

可以發現在1372處,與weight.t()的(1,1)次元,無法完成矩陣乘法,是以要把x處理成:

x = x.reshape(len(x), 1)  # 輸入 x 張量
           

輸出:

tensor([[ 1.0000],
        [ 1.1837],
        [ 1.3673],
        [ 1.5510],
        [ 1.7347],
        [ 1.9184],
        [ 2.1020],
        [ 2.2857],
        [ 2.4694],
        [ 2.6531],
        [ 2.8367],
        [ 3.0204],
        [ 3.2041],
        [ 3.3878],
        [ 3.5714],
        [ 3.7551],
        [ 3.9388],
        [ 4.1224],
        [ 4.3061],
        [ 4.4898],
        [ 4.6735],
        [ 4.8571],
        [ 5.0408],
        [ 5.2245],
        [ 5.4082],
        [ 5.5918],
        [ 5.7755],
        [ 5.9592],
        [ 6.1429],
        [ 6.3265],
        [ 6.5102],
        [ 6.6939],
        [ 6.8776],
        [ 7.0612],
        [ 7.2449],
        [ 7.4286],
        [ 7.6122],
        [ 7.7959],
        [ 7.9796],
        [ 8.1633],
        [ 8.3469],
        [ 8.5306],
        [ 8.7143],
        [ 8.8980],
        [ 9.0816],
        [ 9.2653],
        [ 9.4490],
        [ 9.6327],
        [ 9.8163],
        [10.0000]])

程序已結束,退出代碼為 0
           

繼續閱讀