我对numpy.where
的结果意味着什么感到困惑,以及如何使用它来索引数组。
看看下面的代码示例:
import numpy as np
a = np.random.randn(10,10,2)
indices = np.where(a[:,:,0] > 0.5)
我希望indices
数组为2-dim并包含条件为true的索引。我们可以通过
indices = np.array(indices)
indices.shape # (2,120)
所以看起来indices
正在对某种扁平化阵列进行操作,但我无法弄清楚具体如何。更令人困惑的是,
a.shape # (20,20,2)
a[indices].shape # (2,120,20,2)
问题:
如何使用np.where
的输出索引我的数组实际上增长数组的大小?这里发生了什么?
答案 0 :(得分:4)
您的索引基于错误的假设:np.where
返回一些可以立即用于高级索引的内容(它是np.ndarrays
的元组)。但是你把它转换成一个numpy数组(所以它现在是np.ndarray
的{{1}}。
所以
np.ndarrays
为您提供import numpy as np
a = np.random.randn(10,10,2)
indices = np.where(a[:,:,0] > 0.5)
a[:,:,0][indices]
# If you do a[indices] the result would be different, I'm not sure what
# you intended.
找到的元素。如果您将np.where
转换为indices
,则会触发另一种形式的索引(see this section of the numpy docs),并且docs
中的警告消息非常重要。这就是它增加数组总大小的原因。
有关np.array
含义的一些其他信息:您将获得包含np.where
数组的元组。 n
是输入数组的维数。因此,满足条件的第一个元素具有索引n
而不是[0][0], [1][0], ... [n][0]
。因此,在您的情况下,您有(2,120)意味着您有2个维度和120个找到的点。