天天看点

numpy中where()函数理解

np.where()有两种用法:

1. np.where(condition, x, y)

如果满足条件(condition)则输出x, 不满足条件则输出y。

例子:

>>> import numpy as np
>>> a = np.arange(10)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> np.where(a, 0, 1)
array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> np.where(a>2, 0, 1)
array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0])
>>> np.where([[True, False], [True, True]], )
(array([0, 1, 1], dtype=int64), array([0, 0, 1], dtype=int64))
>>> np.where([[True, False], [True, True]], [[1, 3], [6, 4]], [[4, 2], [7, 8]])
array([[1, 2],
       [6, 4]])

           

 np.where([[True, False], [True, True]], [[1, 3], [6, 4]], [[4, 2], [7, 8]])中[[True, False], [True, True]]为condition, 第一个为True,输出1, 第二个False,输出2,第三个True,输出6, 第四个True, 输出4

2. np.where(condition)

只有条件(condition),没有x和y。则输出满足条件 (即非0) 元素的坐标。这里的坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。

例子:

>>> a = np.array([1, 2, 3, 4])
>>> np.where(a > 3)
(array([3], dtype=int64),)
>>> a[np.where(a > 3)]
array([4])
>>> np.where([[1, 0], [1, 0]])
(array([0, 1], dtype=int64), array([0, 0], dtype=int64))

           

np.where([[1, 0], [0, 1]])中真值为两个1, 坐标分别为[0, 0]和[0, 1]。

参考:https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.where.html