python查找两个n乘1个numpy数组列表之间的匹配

时间:2013-02-21 12:21:42

标签: python numpy

我一直在寻找这种问题的解决方案: 例如(因为我的真正问题更复杂):

import numpy

a=[numpy.array([1,2]),numpy.array([2,2]),numpy.array([3,2]),numpy.array([4,2])]
b=[numpy.array([2,2]),numpy.array([3,2]),numpy.array([6,2]),numpy.array([5,2]),numpy.array([5,2])]

ya=numpy.array([1,2,3,4])
size_a=len(a)
size_b=len(b)
yb=numpy.empty((size_b,1))
yb.fill(numpy.nan)

for i in xrange(size_b):
    for j in xrange(size_a):
        if numpy.array_equiv(yb,ya):
            ya[i]=yb[j]

我只想用符合b的一个元素的元素索引的ya值填充yb。由于yb比ya长,所以yb在循环的末尾将包含“nan”是正常的。 下面的代码需要很长时间才能继续。事实上,我不知道它是否有效,因为我没有等到循环的结束......

在实际情况中,ya和yb更长:7007和3525

还有另一种方法来实现我的目标吗?

1 个答案:

答案 0 :(得分:1)

要查找数组列表之间的匹配,最直接的方法是将列表广播为相同的n x m形状;这可以使用np.tile完成,但使用stride_tricks会更快:

a = np.array(a)
b = np.array(b)
shape = (2, a.shape[0], b.shape[0])
from numpy.lib.stride_tricks import as_strided
a = as_strided(a, shape=shape, strides=(a.strides[1], a.strides[0], 0))
b = as_strided(b, shape=shape, strides=(b.strides[1], 0, b.strides[0]))
np.where(np.all(a == b, axis=0))

这给出了结果

(array([1, 2]), array([0, 1]))

即。 a[1] == b[0]a[2] == b[1],没有其他匹配。