我需要确定矩阵a中k个最大值的位置(索引)是否与二进制指示符矩阵位于同一位置,b。
import numpy as np
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b
d = argsort(a)
d[:,2:] # Return whether these indices are in 'b'
返回:
a:
[[ 0.8 0.2 0.6 0.4]
[ 0.9 0.3 0.8 0.6]
[ 0.2 0.6 0.8 0.4]
[ 0.3 0.3 0.1 0.8]]
b:
[[1 0 0 1]
[1 0 1 1]
[1 1 1 0]
[1 0 0 1]]
matrix([[2, 0],
[2, 0],
[1, 2],
[1, 3]])
我想比较从上一个结果返回的索引,如果b
在这些位置有一个,则返回计数。
对于此示例,最终的期望结果将是:
1
2
2
1
换句话说,在a
的第一行中,前2个值仅对应b
等中的一个。
任何想法如何有效地做到这一点?也许argsort在这里是错误的方法。 感谢。
答案 0 :(得分:1)
当您使用argsort
时,您可以从最小0
到最大3
得到它,因此您可以通过[::-1]
撤消0
以获得最大3
至少s = np.argsort(a, axis=1)[:,::-1]
#array([[0, 2, 3, 1],
# [0, 2, 3, 1],
# [2, 1, 3, 0],
# [3, 1, 0, 2]])
:
np.take
现在,您可以使用0
获取最大值所在的1
和第二个最大值所在的s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
s = np.take(s.flatten(),s2)
#array([[0, 3, 1, 2],
# [0, 3, 1, 2],
# [3, 1, 0, 2],
# [2, 1, 3, 0]])
:
b
在0
中,np.nan
值应替换为0==np.nan
,以便False
提供b = np.float_(b)
b[b==0] = np.nan
#array([[ 1., nan, nan, 1.],
# [ 1., nan, 1., 1.],
# [ 1., 1., 1., nan],
# [ 1., nan, nan, 1.]])
:
print np.logical_or(s==b-1, s==b).sum(axis=1)
#[[1]
# [2]
# [2]
# [1]]
以下比较将为您提供所需的结果:
n
一般情况下,将a
的{{1}}最大值与二进制b
进行比较:
def check_a_b(a,b,n=2):
b = np.float_(b)
b[b==0] = np.nan
s = np.argsort(a, axis=1)[:,::-1]
s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
s = np.take(s.flatten(),s2)
ans = s==(b-1)
for i in range(n-1):
ans = np.logical_or( ans, s==b+i )
return ans.sum(axis=1)
这将在logical_or
中进行成对比较。
答案 1 :(得分:1)
Anothen更简单,更快捷的方法,基于以下事实:
True*1=1, True*0=0, False*0=0, and False*1=0
是:
def check_a_b_new(a,b,n=2):
s = np.argsort(a.view(np.ndarray), axis=1)[:,::-1]
s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
s = np.take(s.flatten(),s2)
return ((s < n)*b.view(np.ndarray)).sum(axis=1)
避免0
到np.nan
转换,以及Python for
循环,这会导致n
的值很高。
答案 2 :(得分:0)
为了回应Saullo的巨大帮助,我能够完成他的工作并将解决方案缩减为三条线。谢谢Saullo!
#Inputs
k = 2
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b
# Return values of interest
s = argsort(a.view(np.ndarray), axis=1)[:,::-1]
s2 = s + (arange(s.shape[0])*s.shape[1])[:,None]
out = take(b,s2).view(np.ndarray)[::,:k].sum(axis=1)
print out
给出:
a:
[[ 0.8 0.2 0.6 0.4]
[ 0.9 0.3 0.8 0.6]
[ 0.2 0.6 0.8 0.4]
[ 0.3 0.3 0.1 0.8]]
b:
[[1 0 0 1]
[1 0 1 1]
[1 1 1 0]
[1 0 0 1]]
Out:
[1 2 2 1]