當x為:
import torch

# Fix the random seed so the noise added to y is reproducible across runs.
torch.manual_seed(10)
# 50 evenly spaced points from 1 to 10 (inclusive). NOTE: this is a 1-D
# tensor of shape (50,), not (50, 1) -- feeding it straight into an
# nn.Linear layer is what triggers the size-mismatch error below.
x = torch.linspace(1, 10, 50)
# Linear target 2*x plus uniform noise in [0, 3); also shape (50,).
# NOTE(review): if y is later used as the regression target for an (N, 1)
# model output, it probably needs the same reshape as x -- confirm against
# the loss call.
y = 2 * x + 3 * torch.rand(50)
print(x)
輸出:
tensor([ 1.0000, 1.1837, 1.3673, 1.5510, 1.7347, 1.9184, 2.1020, 2.2857,
2.4694, 2.6531, 2.8367, 3.0204, 3.2041, 3.3878, 3.5714, 3.7551,
3.9388, 4.1224, 4.3061, 4.4898, 4.6735, 4.8571, 5.0408, 5.2245,
5.4082, 5.5918, 5.7755, 5.9592, 6.1429, 6.3265, 6.5102, 6.6939,
6.8776, 7.0612, 7.2449, 7.4286, 7.6122, 7.7959, 7.9796, 8.1633,
8.3469, 8.5306, 8.7143, 8.8980, 9.0816, 9.2653, 9.4490, 9.6327,
9.8163, 10.0000])
時,x 是形狀為 (50,) 的一維張量;將它直接代入 nn.Linear 層(self.linear(x))後出現以下錯誤,因為 F.linear 內部的矩陣乘法會把它視為 1×50 的矩陣:
---> 10 x = self.linear(x)
11 return x
12 ### 代碼結束 ###
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
85
86 def forward(self, input):
---> 87 return F.linear(input, self.weight, self.bias)
88
89 def extra_repr(self):
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1370 ret = torch.addmm(bias, input, weight.t())
1371 else:
-> 1372 output = input.matmul(weight.t())
1373 if bias is not None:
1374 output += bias
RuntimeError: size mismatch, m1: [1 x 50], m2: [1 x 1] at /opt/conda/conda-bld/pytorch_1579022060824/work/aten/src/TH/generic/THTensorMath.cpp:136
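由第 1372 行可以看出:輸入被視為 1×50 的矩陣(m1),而 weight.t() 是 1×1(m2),兩者維度不匹配(50 ≠ 1),無法完成矩陣乘法。nn.Linear 要求輸入的最後一維等於 in_features(此處為 1),所以要把 x 重塑成形狀 (50, 1),即 50 個樣本、每個樣本 1 個特徵: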
# Turn the 1-D (50,) tensor into a (50, 1) column: 50 samples with one
# feature each, so its last dim matches the layer's in_features of 1.
x = x.reshape(-1, 1)
重塑後 print(x) 的輸出:
tensor([[ 1.0000],
[ 1.1837],
[ 1.3673],
[ 1.5510],
[ 1.7347],
[ 1.9184],
[ 2.1020],
[ 2.2857],
[ 2.4694],
[ 2.6531],
[ 2.8367],
[ 3.0204],
[ 3.2041],
[ 3.3878],
[ 3.5714],
[ 3.7551],
[ 3.9388],
[ 4.1224],
[ 4.3061],
[ 4.4898],
[ 4.6735],
[ 4.8571],
[ 5.0408],
[ 5.2245],
[ 5.4082],
[ 5.5918],
[ 5.7755],
[ 5.9592],
[ 6.1429],
[ 6.3265],
[ 6.5102],
[ 6.6939],
[ 6.8776],
[ 7.0612],
[ 7.2449],
[ 7.4286],
[ 7.6122],
[ 7.7959],
[ 7.9796],
[ 8.1633],
[ 8.3469],
[ 8.5306],
[ 8.7143],
[ 8.8980],
[ 9.0816],
[ 9.2653],
[ 9.4490],
[ 9.6327],
[ 9.8163],
[10.0000]])
程序已結束,退出代碼為 0