如何在numpy中找到给定行的值的列索引?

时间:2016-02-24 21:02:35

标签: python numpy

我想在numpy中执行一些相对简单的事情:

  • 如果行中有一个,则返回包含一个+1的列的索引。
  • 如果行中有零个或几个返回0。

但是我的代码非常复杂:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1

如果我打印result,程序将打印出来 [ 1. 0. 4. 0.]这是期望的结果。

我想知道是否有一种使用numpy编写最后一行的简单方法。

1 个答案:

答案 0 :(得分:1)

我不确定是否更好,但您可以尝试使用argmax。此外,您不需要使用for循环和np.where来获取有效索引:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1

In [55]: result
Out[55]: array([ 1.,  0.,  4.,  0.])

或者你可以使用np.whereargmax在一行中完成所有这些工作:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])