官網例子
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a, dim=1)
tensor([ 0, 2, 0, 1])
再給出一些例子:
input=torch.FloatTensor([1,3,1,8,0])
output=torch.argmax(input,dim=0)
print('input shape ',input.shape)
print('output shape ',output.shape)
print(output)
程式輸出:
input shape torch.Size([5])
output shape torch.Size([])
tensor(3)
input = torch.FloatTensor([[[1,0],
[0,0]],
[[2,2],
[2,6]],
[[3,7],
[3,3]],
[[9,9],
[0,0]]])
output=torch.argmax(input,dim=0)
print('input shape ',input.shape)
print('output shape ',output.shape)
print(output)
程式輸出:
input shape torch.Size([4, 2, 2])
output shape torch.Size([2, 2])
tensor([[3, 3],
[2, 1]])
首先了解輸出輸出的次元變化,
input shape 4,2,2
argmax(input,dim=0)
output的次元就會少了第零維,變成了 (2,2)
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZuBnLzkDO5UDOwETMyIzMwEjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
如圖所示,四根斜線代表在這4個地方取一個最大值,就是斜線穿過的四個點取一個最大值。
最後就生成了四個值,形狀是(2,2)