Python - 在一个矩阵的每一行中找到K max值并与二进制矩阵进行比较

时间:2013-08-07 23:58:39

标签: python numpy matrix

我需要确定矩阵a中k个最大值的位置(索引)是否与二进制指示符矩阵位于同一位置,b。

import numpy as np
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b

d = argsort(a)
d[:,2:] # Return whether these indices are in 'b'

返回:

a:
[[ 0.8  0.2  0.6  0.4]
 [ 0.9  0.3  0.8  0.6]
 [ 0.2  0.6  0.8  0.4]
 [ 0.3  0.3  0.1  0.8]]
b:
[[1 0 0 1]
 [1 0 1 1]
 [1 1 1 0]
 [1 0 0 1]]

matrix([[2, 0],
        [2, 0],
        [1, 2],
        [1, 3]])

我想比较从上一个结果返回的索引,如果b在这些位置有一个,则返回计数。 对于此示例,最终的期望结果将是:

1
2
2
1

换句话说,在a的第一行中,前2个值仅对应b等中的一个。

任何想法如何有效地做到这一点?也许argsort在这里是错误的方法。 感谢。

3 个答案:

答案 0 :(得分:1)

当您使用argsort时,您可以从最小0到最大3得到它,因此您可以通过[::-1]撤消0以获得最大3至少s = np.argsort(a, axis=1)[:,::-1] #array([[0, 2, 3, 1], # [0, 2, 3, 1], # [2, 1, 3, 0], # [3, 1, 0, 2]])

np.take

现在,您可以使用0获取最大值所在的1和第二个最大值所在的s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None] s = np.take(s.flatten(),s2) #array([[0, 3, 1, 2], # [0, 3, 1, 2], # [3, 1, 0, 2], # [2, 1, 3, 0]])

b

0中,np.nan值应替换为0==np.nan,以便False提供b = np.float_(b) b[b==0] = np.nan #array([[ 1., nan, nan, 1.], # [ 1., nan, 1., 1.], # [ 1., 1., 1., nan], # [ 1., nan, nan, 1.]])

print np.logical_or(s==b-1, s==b).sum(axis=1)
#[[1]
# [2]
# [2]
# [1]]

以下比较将为您提供所需的结果:

n

一般情况下,将a的{​​{1}}最大值与二进制b进行比较:

def check_a_b(a,b,n=2):
    b = np.float_(b)
    b[b==0] = np.nan
    s = np.argsort(a, axis=1)[:,::-1]
    s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
    s = np.take(s.flatten(),s2)
    ans = s==(b-1)
    for i in range(n-1):
        ans = np.logical_or( ans, s==b+i )
    return ans.sum(axis=1)

这将在logical_or中进行成对比较。

答案 1 :(得分:1)

Anothen更简单,更快捷的方法,基于以下事实:

True*1=1, True*0=0, False*0=0, and False*1=0

是:

def check_a_b_new(a,b,n=2):
    s = np.argsort(a.view(np.ndarray), axis=1)[:,::-1]
    s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
    s = np.take(s.flatten(),s2)
    return ((s < n)*b.view(np.ndarray)).sum(axis=1)

避免0np.nan转换,以及Python for循环,这会导致n的值很高。

答案 2 :(得分:0)

为了回应Saullo的巨大帮助,我能够完成他的工作并将解决方案缩减为三条线。谢谢Saullo!

#Inputs
k = 2
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b

# Return values of interest
s = argsort(a.view(np.ndarray), axis=1)[:,::-1]
s2 = s + (arange(s.shape[0])*s.shape[1])[:,None]
out = take(b,s2).view(np.ndarray)[::,:k].sum(axis=1)
print out

给出:

a:
[[ 0.8  0.2  0.6  0.4]
 [ 0.9  0.3  0.8  0.6]
 [ 0.2  0.6  0.8  0.4]
 [ 0.3  0.3  0.1  0.8]]
b:
[[1 0 0 1]
 [1 0 1 1]
 [1 1 1 0]
 [1 0 0 1]]
Out:
[1 2 2 1]