我想在numpy中执行一些相对简单的事情:
但是我的代码非常复杂:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)
result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
result[idx] = np.where(predictions[idx,:]==1)[1] + 1
如果我打印result
,程序将打印出来
[ 1. 0. 4. 0.]
这是期望的结果。
我想知道是否有一种使用numpy编写最后一行的简单方法。
答案 0 :(得分:1)
我不确定是否更好,但您可以尝试使用argmax
。此外,您不需要使用for循环和np.where
来获取有效索引:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1
In [55]: result
Out[55]: array([ 1., 0., 4., 0.])
或者你可以使用np.where
和argmax
在一行中完成所有这些工作:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])