我正在转换这个数组:
x = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 0]])
至:[2, 0, 1, 0, 0]
。
基本上,我想返回每个子数组中第一个1
的索引。但是,我的问题是我不知道如何处理没有1
的场景。如果找不到0
,我希望它返回1
。(例如我的例子)。
下面的代码运行正常,但会为我提到的场景抛出IndexError: index 0 is out of bounds for axis 0 with size 0
:
np.array([np.where(r == 1)[0][0] for r in x])
处理此问题的简单方法是什么?它不需要限制为numpy.where。
我顺便使用Python 3。
答案 0 :(得分:2)
使用mask
1s
然后argmax
沿每行获取第一个匹配的索引以及any
以检查有效行(至少有一行{{ 1}}) -
1
现在,所有mask = x==1
idx = np.where(mask.any(1), mask.argmax(1),0)
上的argmax
都会返回False
。因此,这就完全可以解决所述问题。因此,我们可以简单地使用0
结果。但是在一般情况下,无效说明符,我们称之为mask.argmax(1)
不是invalid_val
,我们可以在0
内指定,就像这样 -
np.where
另一种方法是获取掩码上的第一个匹配索引,然后索引到掩码中以查看是否有任何索引值为idx = np.where(mask.any(1), mask.argmax(1),invalid_val)
并将其设置为False
-
0s
答案 1 :(得分:1)
对代码进行简单修改就是为列表推导添加一个条件:
np.array([np.where(r == 1)[0][0] if 1 in r else 0 for r in x])
# 23.1 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
获得相同结果的更简洁,更快捷的方法是:
np.argmax(x == 1, axis=1)
# 4.04 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
或等同于:
np.argmin(x != 1, axis=1)
# 4.03 µs ± 13.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)