过滤具有行特定条件的numpy数组

时间:2015-02-10 20:29:06

标签: python numpy

假设我有一个2d numpy数组,我想过滤每行传递一定标准的元素。例如,我只想要特定行的第90百分位数以上的元素。我已经提出了这个解决方案:

import numpy as np
a = np.random.random((6,5))
thresholds = np.percentile(a, 90, axis=1)
threshold_2d = np.vstack([thresholds]*a.shape[1]).T
mask = a > threshold_2d
final = np.where(mask, a, np.nan)

它有效并且它被矢量化但感觉有点尴尬,尤其是我创建threshold_2d的部分。有更优雅的方式吗?我可以以某种方式使用np.where自动广播条件而无需创建匹配的2d掩码吗?

1 个答案:

答案 0 :(得分:2)

广播

In [36]: np.random.seed(1023)

In [37]: a = np.random.random((6,5))

In [38]: thresholds = np.percentile(a, 90, axis=1)

In [39]: threshold_2d = np.vstack([thresholds]*a.shape[1]).T

In [40]: a>threshold_2d
Out[40]: 
array([[ True, False, False, False, False],
       [False, False,  True, False, False],
       [False,  True, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False,  True],
       [False,  True, False, False, False]], dtype=bool)

In [41]: a>thresholds[:,np.newaxis]
Out[41]: 
array([[ True, False, False, False, False],
       [False, False,  True, False, False],
       [False,  True, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False,  True],
       [False,  True, False, False, False]], dtype=bool)

In [42]: 

numpy.newaxis创建一个长度为1的轴,生成的数组视图具有维度(6,1),并且可以使用a arrray进行广播。