如何处理numpy.where条件不满意的情况?

时间:2017-10-25 11:37:08

标签: python arrays numpy

我正在转换这个数组:

x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])

至:[2, 0, 1, 0, 0]

基本上,我想返回每个子数组中第一个1的索引。但是,我的问题是我不知道如何处理没有1的场景。如果找不到0,我希望它返回1。(例如我的例子)。

下面的代码运行正常,但会为我提到的场景抛出IndexError: index 0 is out of bounds for axis 0 with size 0

np.array([np.where(r == 1)[0][0] for r in x])

处理此问题的简单方法是什么?它不需要限制为numpy.where。

我顺便使用Python 3。

2 个答案:

答案 0 :(得分:2)

使用mask 1s然后argmax沿每行获取第一个匹配的索引以及any以检查有效行(至少有一行{{ 1}}) -

1

现在,所有mask = x==1 idx = np.where(mask.any(1), mask.argmax(1),0) 上的argmax都会返回False。因此,这就完全可以解决所述问题。因此,我们可以简单地使用0结果。但是在一般情况下,无效说明符,我们称之为mask.argmax(1)不是invalid_val,我们可以在0内指定,就像这样 -

np.where

另一种方法是获取掩码上的第一个匹配索引,然后索引到掩码中以查看是否有任何索引值为idx = np.where(mask.any(1), mask.argmax(1),invalid_val) 并将其设置为False -

0s

答案 1 :(得分:1)

对代码进行简单修改就是为列表推导添加一个条件:

np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])
# 23.1 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

获得相同结果的更简洁,更快捷的方法是:

np.argmax(x == 1, axis=1)
# 4.04 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

或等同于:

np.argmin(x != 1, axis=1)
# 4.03 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)