numpy.where(condition)的输出不是一个数组,而是一个数组的元组:为什么?

时间:2015-11-17 01:59:32

标签: python arrays numpy

我正在尝试numpy.where(condition[, x, y])功能 从numpy documentation,我得知如果你只给一个数组作为输入,它应该返回数组非零的索引(即" True"):

  

如果只给出条件,则返回元组condition.nonzero(),.   条件为真的索引。

但是如果尝试的话,它会返回两个元素的元组,其中第一个是索引的通缉列表,第二个是null元素:

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> np.where(array>4)
(array([4, 5, 6, 7, 8]),) # notice the comma before the last parenthesis

所以问题是:为什么?这种行为的目的是什么?在什么情况下这是有用的? 实际上,要获得所需的索引列表,我必须添加索引,如np.where(array>4)[0],这似乎......"丑陋"。

附录

我理解(从某些答案)它实际上只是一个元素的元组。我仍然不明白为什么要以这种方式提供输出。为了说明这是不理想的,请考虑以下错误(这首先激发了我的问题):

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.where(array>4)
>>> pippo + 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can only concatenate tuple (not "int") to tuple

这样你就需要做一些索引来访问实际的索引数组:

>>> pippo[0] + 1
array([5, 6, 7, 8, 9])

3 个答案:

答案 0 :(得分:30)

在Python (1)中仅表示1()可以自由添加到群组编号和表达式中以供人类阅读(例如(1+3)*3 v (1+3,)*3)。因此,为了表示1元素元组,它使用(1,)(并要求您也使用它)。

因此

(array([4, 5, 6, 7, 8]),)

是一个元素元组,该元素是一个数组。

如果您将where应用于2d数组,则结果将是2元素元组。

where的结果是可以将其直接插入索引槽,例如

a[where(a>0)]
a[a>0]

应该返回相同的东西

就像

一样
I,J = where(a>0)   # a is 2d
a[I,J]
a[(I,J)]

或者用你的例子:

In [278]: a=np.array([1,2,3,4,5,6,7,8,9])
In [279]: np.where(a>4)
Out[279]: (array([4, 5, 6, 7, 8], dtype=int32),)  # tuple

In [280]: a[np.where(a>4)]
Out[280]: array([5, 6, 7, 8, 9])

In [281]: I=np.where(a>4)
In [282]: I
Out[282]: (array([4, 5, 6, 7, 8], dtype=int32),)
In [283]: a[I]
Out[283]: array([5, 6, 7, 8, 9])

In [286]: i, = np.where(a>4)   # note the , on LHS
In [287]: i
Out[287]: array([4, 5, 6, 7, 8], dtype=int32)  # not tuple
In [288]: a[i]
Out[288]: array([5, 6, 7, 8, 9])
In [289]: a[(i,)]
Out[289]: array([5, 6, 7, 8, 9])

======================

无论输入数组的大小如何,

np.flatnonzero都会显示返回一个数组的正确方法。

In [299]: np.flatnonzero(a>4)
Out[299]: array([4, 5, 6, 7, 8], dtype=int32)
In [300]: np.flatnonzero(a>4)+10
Out[300]: array([14, 15, 16, 17, 18], dtype=int32)

它的医生说:

  

这相当于a.ravel()。nonzero()[0]

事实上,这实际上就是函数所做的事情。

通过扁平化a消除了如何处理多个维度的问题。然后它从元组中获取响应,给你一个简单的数组。对于扁平化,它并没有为1d阵列做出特殊情况。

===========================

@Divakar建议np.argwhere

In [303]: np.argwhere(a>4)
Out[303]: 
array([[4],
       [5],
       [6],
       [7],
       [8]], dtype=int32)

np.transpose(np.where(a>4))

或者如果你不喜欢列向量,你可以再次转置它

In [307]: np.argwhere(a>4).T
Out[307]: array([[4, 5, 6, 7, 8]], dtype=int32)

除了现在它是1xn数组。

我们也可以在where中包裹array

In [311]: np.array(np.where(a>4))
Out[311]: array([[4, 5, 6, 7, 8]], dtype=int32)

where元组([0]i,=transposearray等取出数组的方法很多。

答案 1 :(得分:6)

简短回答:np.where旨在提供一致的输出,无论数组的维数如何。

二维数组有两个索引,因此np.where的结果是包含相关索引的长度为2的元组。这推广为3维的长度为3的元组,4维的长度为4的元组,或N维的长度为N的元组。通过这个规则,很明显在1维中,结果应该是长度为1的元组。

答案 2 :(得分:-1)

只需使用np.asarray功能即可。在你的情况下:

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.asarray(np.where(array>4))
>>> pippo + 1
array([[5, 6, 7, 8, 9]])