学习pytorch中,看到文档里关于torch.nonzero的介绍和举例,一维的那个例子还好理解,二维的就不是很理解了,不明白为什么会出现两个[0,1,2,3],于是再网上查看了一些讲解,终于搞明白了是怎么回事。
先看一下原函数中每个元素的介绍
torch.nonzero(input, *, out=None, as_tuple=False)
- input:输入的必须是tensor
- out:输出 Z × N , N 代表输入数据的维度, Z 是总共非0元素的个数
- as_tuple:
输出的每一行为非零元素的索引if as_tuple = False:
输出是每一个维度都有一个一维的张量if as_tuple = True:
看个例子
torch.nonzero(torch.Tensor([[6.6, 0.0, 0.0],
... [0.0, 3.3, 0.0],
... [0.0, 0.0, 1.1]]))
输出结果:
tensor([[0, 0],
[1, 1],
[2, 2]])
首先,当
as_tuple
元素未给出时,默认
as_tuple = False
;
然后根据这里的结果进行一下解读:这个例子里input是2维的,一共有3个非0元素,所以输出是一个
3×2
的张量,表示每个非0元素的索引。读法是从左往右,比如out的第0行
[0,0]
,表示的就是input的第0行的第0个元素是非0元素;同理,out的第1行
[1,1]
,表示的就是input的第1行的第
在这里插入代码片
1个元素是非0元素,等等。
举个例子
我们将input设为一个3维张量,举一反三
torch.nonzero(torch.Tensor([[[1,1,1,0,1],[1,0,0,0,1]],
[[1,1,1,0,1],[1,0,0,0,1]]]))
输出结果:
0 0 0
0 0 1
0 0 2
0 0 4
0 1 0
0 1 4
1 0 0
1 0 1
1 0 2
1 0 4
1 1 0
1 1 4
这个out张量的意思就是:
按行依次从左往右读,第0行第0列第0个元素非0,第0行第0列第1个元素非0,……,第1行第1列第0个元素非0,第1行第1列第4个元素非0.
进一步举例
当
as_tuple = True:
输出是每一个维度都有一个一维的张量
例1
torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
输出结果:
(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
解读: 组合起来就是(0,0)(1,1)(2,2)(3,3)
例2:
输出结果:
参考:
https://blog.csdn.net/monchin/article/details/79750216
https://blog.csdn.net/qq_36530992/article/details/102836509