加速查找等效行的python中的ndarray比较

时间:2016-10-30 18:17:36

标签: performance python-2.7 numpy

在Python 2.7中,是否有一种快速比较ndarray中等效行的方法? 我在一些坐标上应用对称操作,我将其存储在形状数组(N,4)的每一行中。我需要一种方法来判断我的转换是否将坐标映射回等效位置。需要注意的是,即使位置可能相同,它们也存储在数组的不同行中,因此这需要在比较之前对数组进行排序。如果我只需要调用一次就可以了,但是在我的代码中这个函数被称为~10,000次。

对此进行基准测试表明这需要约60μs:

%timeit structs_are_equiv_old(a,b)
The slowest run took 6.36 times longer than the fastest. This could mean that     
an intermediate result is being cached.
10000 loops, best of 3: 59.6 µs per loop

有没有办法加快这种比较?

def structs_are_equiv(a, b):
    """
    compares two numpy arrays row by row to determine if they contain the   
    coordinates after the application of a transformation operation.
    """
    assert a.shape == b.shape
    a_temp = a[ np.lexsort( (a[:,3], a[:,2], a[:,1], a[:,0]) ) ]
    b_temp = b[ np.lexsort( (b[:,3], b[:,2], b[:,1], b[:,0]) ) ]

    return np.allclose( a_temp, b_temp )

示例a和b(注意第一列不参与转换,只是表示存储在坐标处的对象类型的一种方式):

a = array([[ 1.      ,  0.      ,  0.5     ,  0.271149],
   [ 1.      ,  0.5     ,  0.5     ,  0.271149],
   [ 1.      ,  0.      ,  0.      ,  0.303063],
   [ 1.      ,  0.5     ,  0.      ,  0.303063],
   [ 2.      ,  0.25    ,  0.      ,  0.358071],
   [ 2.      ,  0.75    ,  0.      ,  0.358071],
   [ 1.      ,  0.25    ,  0.306215,  0.358071],
   [ 1.      ,  0.75    ,  0.306215,  0.358071],
   [ 2.      ,  0.      ,  0.5     ,  0.358071],
   [ 2.      ,  0.5     ,  0.5     ,  0.358071],
   [ 1.      ,  0.25    ,  0.693785,  0.358071],
   [ 1.      ,  0.75    ,  0.693785,  0.358071],
   [ 1.      ,  0.      ,  0.      ,  0.413078],
   [ 1.      ,  0.5     ,  0.      ,  0.413078],
   [ 1.      ,  0.      ,  0.5     ,  0.444992],
   [ 1.      ,  0.5     ,  0.5     ,  0.444992],
   [ 2.      ,  0.      ,  0.      ,  0.5     ],
   [ 2.      ,  0.5     ,  0.      ,  0.5     ],
   [ 1.      ,  0.25    ,  0.193785,  0.5     ],
   [ 1.      ,  0.75    ,  0.193785,  0.5     ],
   [ 2.      ,  0.25    ,  0.5     ,  0.5     ],
   [ 2.      ,  0.75    ,  0.5     ,  0.5     ],
   [ 1.      ,  0.25    ,  0.806215,  0.5     ],
   [ 1.      ,  0.75    ,  0.806215,  0.5     ],
   [ 1.      ,  0.      ,  0.5     ,  0.555008],
   [ 1.      ,  0.5     ,  0.5     ,  0.555008],
   [ 1.      ,  0.      ,  0.      ,  0.586922],
   [ 1.      ,  0.5     ,  0.      ,  0.586922],
   [ 2.      ,  0.25    ,  0.      ,  0.641929],
   [ 2.      ,  0.75    ,  0.      ,  0.641929],
   [ 1.      ,  0.25    ,  0.306215,  0.641929],
   [ 1.      ,  0.75    ,  0.306215,  0.641929],
   [ 2.      ,  0.      ,  0.5     ,  0.641929],
   [ 2.      ,  0.5     ,  0.5     ,  0.641929],
   [ 1.      ,  0.25    ,  0.693785,  0.641929],
   [ 1.      ,  0.75    ,  0.693785,  0.641929],
   [ 1.      ,  0.      ,  0.      ,  0.696937],
   [ 1.      ,  0.5     ,  0.      ,  0.696937],
   [ 1.      ,  0.      ,  0.5     ,  0.728851],
   [ 1.      ,  0.5     ,  0.5     ,  0.728851],
   [ 0.      ,  0.117635,  0.5     ,  0.238728],
   [ 0.      ,  0.617635,  0.5     ,  0.238728],
   [ 0.      ,  0.      ,  0.114216,  0.270642],
   [ 0.      ,  0.5     ,  0.114216,  0.270642],
   [ 0.      ,  0.      ,  0.      ,  0.270642],
   [ 0.      ,  0.5     ,  0.      ,  0.270642],
   [ 0.      ,  0.617635,  0.5     ,  0.761272],
   [ 0.      ,  0.117635,  0.5     ,  0.761272],
   [ 0.      ,  0.5     ,  0.114216,  0.729358],
   [ 0.      ,  0.      ,  0.114216,  0.729358],
   [ 0.      ,  0.5     ,  0.      ,  0.729358],
   [ 0.      ,  0.      ,  0.      ,  0.729358],
   [ 0.      ,  0.25    ,  0.306215,  0.401299],
   [ 0.      ,  0.75    ,  0.306215,  0.401299],
   [ 0.      ,  0.25    ,  0.693785,  0.401299],
   [ 0.      ,  0.75    ,  0.693785,  0.401299],
   [ 0.      ,  0.25    ,  0.306215,  0.598701],
   [ 0.      ,  0.75    ,  0.306215,  0.598701],
   [ 0.      ,  0.25    ,  0.693785,  0.598701],
   [ 0.      ,  0.75    ,  0.693785,  0.598701],
   [ 0.      ,  0.117635,  0.5     ,  0.226923],
   [ 0.      ,  0.117635,  0.5     ,  0.773077],
   [ 0.      ,  0.      ,  0.114216,  0.260279],
   [ 0.      ,  0.      ,  0.114216,  0.739721],
   [ 0.      ,  0.      ,  0.885784,  0.260279],
   [ 0.      ,  0.      ,  0.885784,  0.739721],
   [ 0.      ,  0.5     ,  0.885784,  0.260279],
   [ 0.      ,  0.5     ,  0.885784,  0.739721],
   [ 0.      ,  0.25    ,  0.306215,  0.401299],
   [ 0.      ,  0.25    ,  0.306215,  0.598701],
   [ 0.      ,  0.75    ,  0.306215,  0.401299],
   [ 0.      ,  0.75    ,  0.306215,  0.598701],
   [ 0.      ,  0.75    ,  0.693785,  0.401299],
   [ 0.      ,  0.75    ,  0.693785,  0.598701]])

b = nparray([[ 1.      ,  0.5     ,  0.5     ,  0.271149],
   [ 1.      ,  0.      ,  0.5     ,  0.271149],
   [ 1.      ,  0.5     ,  0.      ,  0.303063],
   [ 1.      ,  0.      ,  0.      ,  0.303063],
   [ 2.      ,  0.75    ,  0.      ,  0.358071],
   [ 2.      ,  0.25    ,  0.      ,  0.358071],
   [ 1.      ,  0.75    ,  0.306215,  0.358071],
   [ 1.      ,  0.25    ,  0.306215,  0.358071],
   [ 2.      ,  0.5     ,  0.5     ,  0.358071],
   [ 2.      ,  0.      ,  0.5     ,  0.358071],
   [ 1.      ,  0.75    ,  0.693785,  0.358071],
   [ 1.      ,  0.25    ,  0.693785,  0.358071],
   [ 1.      ,  0.5     ,  0.      ,  0.413078],
   [ 1.      ,  0.      ,  0.      ,  0.413078],
   [ 1.      ,  0.5     ,  0.5     ,  0.444992],
   [ 1.      ,  0.      ,  0.5     ,  0.444992],
   [ 2.      ,  0.5     ,  0.      ,  0.5     ],
   [ 2.      ,  0.      ,  0.      ,  0.5     ],
   [ 1.      ,  0.75    ,  0.193785,  0.5     ],
   [ 1.      ,  0.25    ,  0.193785,  0.5     ],
   [ 2.      ,  0.75    ,  0.5     ,  0.5     ],
   [ 2.      ,  0.25    ,  0.5     ,  0.5     ],
   [ 1.      ,  0.75    ,  0.806215,  0.5     ],
   [ 1.      ,  0.25    ,  0.806215,  0.5     ],
   [ 1.      ,  0.5     ,  0.5     ,  0.555008],
   [ 1.      ,  0.      ,  0.5     ,  0.555008],
   [ 1.      ,  0.5     ,  0.      ,  0.586922],
   [ 1.      ,  0.      ,  0.      ,  0.586922],
   [ 2.      ,  0.75    ,  0.      ,  0.641929],
   [ 2.      ,  0.25    ,  0.      ,  0.641929],
   [ 1.      ,  0.75    ,  0.306215,  0.641929],
   [ 1.      ,  0.25    ,  0.306215,  0.641929],
   [ 2.      ,  0.5     ,  0.5     ,  0.641929],
   [ 2.      ,  0.      ,  0.5     ,  0.641929],
   [ 1.      ,  0.75    ,  0.693785,  0.641929],
   [ 1.      ,  0.25    ,  0.693785,  0.641929],
   [ 1.      ,  0.5     ,  0.      ,  0.696937],
   [ 1.      ,  0.      ,  0.      ,  0.696937],
   [ 1.      ,  0.5     ,  0.5     ,  0.728851],
   [ 1.      ,  0.      ,  0.5     ,  0.728851],
   [ 0.      ,  0.617635,  0.5     ,  0.238728],
   [ 0.      ,  0.117635,  0.5     ,  0.238728],
   [ 0.      ,  0.5     ,  0.114216,  0.270642],
   [ 0.      ,  0.      ,  0.114216,  0.270642],
   [ 0.      ,  0.5     ,  0.      ,  0.270642],
   [ 0.      ,  0.      ,  0.      ,  0.270642],
   [ 0.      ,  0.117635,  0.5     ,  0.761272],
   [ 0.      ,  0.617635,  0.5     ,  0.761272],
   [ 0.      ,  0.      ,  0.114216,  0.729358],
   [ 0.      ,  0.5     ,  0.114216,  0.729358],
   [ 0.      ,  0.      ,  0.      ,  0.729358],
   [ 0.      ,  0.5     ,  0.      ,  0.729358],
   [ 0.      ,  0.75    ,  0.306215,  0.401299],
   [ 0.      ,  0.25    ,  0.306215,  0.401299],
   [ 0.      ,  0.75    ,  0.693785,  0.401299],
   [ 0.      ,  0.25    ,  0.693785,  0.401299],
   [ 0.      ,  0.75    ,  0.306215,  0.598701],
   [ 0.      ,  0.25    ,  0.306215,  0.598701],
   [ 0.      ,  0.75    ,  0.693785,  0.598701],
   [ 0.      ,  0.25    ,  0.693785,  0.598701],
   [ 0.      ,  0.117635,  0.5     ,  0.226923],
   [ 0.      ,  0.117635,  0.5     ,  0.773077],
   [ 0.      ,  0.      ,  0.114216,  0.260279],
   [ 0.      ,  0.      ,  0.114216,  0.739721],
   [ 0.      ,  0.      ,  0.885784,  0.260279],
   [ 0.      ,  0.      ,  0.885784,  0.739721],
   [ 0.      ,  0.75    ,  0.306215,  0.401299],
   [ 0.      ,  0.75    ,  0.306215,  0.598701],
   [ 0.      ,  0.25    ,  0.306215,  0.401299],
   [ 0.      ,  0.25    ,  0.306215,  0.598701],
   [ 0.      ,  0.75    ,  0.693785,  0.401299],
   [ 0.      ,  0.75    ,  0.693785,  0.598701],
   [ 0.      ,  0.25    ,  0.693785,  0.401299],
   [ 0.      ,  0.25    ,  0.693785,  0.598701]])

2 个答案:

答案 0 :(得分:2)

这是一种方法,考虑{{3}}将每一行减少为一个标量,然后简单地进行排序和比较,就像这样 -

def structs_are_equiv_dotreduc(a,b):
    scale = 10000**np.arange(1,4)
    a0 = np.sort(a[:,1:].dot(scale).astype(int))
    b0 = np.sort(b[:,1:].dot(scale).astype(int))    
    return (a0 == b0).all()

运行时测试 -

In [538]: # Setup inputs with b array just a row-permuted version of a
     ...: a = np.random.rand(100,4)
     ...: b = a[np.random.permutation(a.shape[0])]
     ...: 

In [539]: %timeit structs_are_equiv(a,b)
10000 loops, best of 3: 117 µs per loop

In [540]: %timeit structs_are_equiv_dotreduc(a,b)
10000 loops, best of 3: 42.7 µs per loop

答案 1 :(得分:0)

来自numpy_indexed包的

npi.sort应该比您当前的解决方案更快;虽然如果他的假设确实存在,那么divakar的解决方案应该更快。