Efficiently finding the indices that make an array equal to a permutation of itself

Date: 2018-10-04 01:44:55

Tags: numpy permutation

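I am looking for a function that finds the indices which make an array equal to a permutation of itself.

Suppose p1 is a 1-D NumPy array containing no duplicates, and suppose p2 is a permutation (reordering) of p1.

I would like a function find_position_in_permutation such that p2[find_position_in_permutation(p1, p2)] is identical to p1.

For example: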

p1 = np.array(['a', 'e', 'c', 'f'])
p2 = np.array(['e', 'f', 'a', 'c'])

find_position_in_permutation(p1, p2) should return:

[2, 0, 3, 1]

because p2[[2, 0, 3, 1]] is identical to p1.

You can do this by brute force using lists:

def find_position_in_permutation(original, permutation):
    original = list(original)
    permutation = list(permutation)
    return list(map(permutation.index, original))

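But I wonder whether there is a more efficient algorithm. This one appears to be O(N^2), since each list.index call is a linear scan and it is called once per element.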
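As a sanity check, the brute-force function can be run on the four-letter example. This is only a minimal sketch; it repeats the definition from the question so that it runs on its own, and assumes nothing beyond NumPy.

```python
import numpy as np

def find_position_in_permutation(original, permutation):
    # Brute force: one linear list.index scan per element, so O(N^2) overall.
    original = list(original)
    permutation = list(permutation)
    return list(map(permutation.index, original))

p1 = np.array(['a', 'e', 'c', 'f'])
p2 = np.array(['e', 'f', 'a', 'c'])

idx = find_position_in_permutation(p1, p2)

# Indexing p2 with the result must reproduce p1 exactly.
assert np.array_equal(p2[idx], p1)
print(idx)
```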


Benchmark of the current answers

import numpy as np
from string import ascii_lowercase

n = 100

letters = np.array([*ascii_lowercase])
# NOTE: np.random.choice samples with replacement, so with n > 26 p1 contains duplicates
p1 = np.random.choice(letters, size=n)
p2 = np.random.permutation(p1)
p1l = p1.tolist()
p2l = p2.tolist()

def find_pos_in_perm_1(original, permutation):
    """ My original solution """
    return list(map(permutation.index, original))

def find_pos_in_perm_2(original, permutation):
    """ Eric Postpischil's solution, using a dict as a lookup table """
    tbl = {val: ix for ix, val in enumerate(permutation)}
    return [tbl[val] for val in original]

def find_pos_in_perm_3(original, permutation):
    """ Paul Panzer's solution, using an array as a lookup table """
    original_argsort = np.argsort(original)
    permutation_argsort = np.argsort(permutation)
    tbl = np.empty_like(original_argsort)
    tbl[original_argsort] = permutation_argsort
    return tbl

%timeit find_pos_in_perm_1(p1l, p2l)
# 40.5 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit find_pos_in_perm_2(p1l, p2l)
# 10 µs ± 171 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit find_pos_in_perm_3(p1, p2)
# 6.38 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
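The three functions are only guaranteed to return the same indices when the input has no duplicates, so a quick consistency check on duplicate-free data may be worthwhile before comparing timings. The sketch below repeats the three definitions from the benchmark; the check itself and the fixed seed are additions for illustration and were not part of the original benchmark.

```python
import numpy as np
from string import ascii_lowercase

def find_pos_in_perm_1(original, permutation):
    # Brute force via list.index
    return list(map(permutation.index, original))

def find_pos_in_perm_2(original, permutation):
    # Dict as a lookup table: value -> position in permutation
    tbl = {val: ix for ix, val in enumerate(permutation)}
    return [tbl[val] for val in original]

def find_pos_in_perm_3(original, permutation):
    # Array as a lookup table built from two argsorts
    original_argsort = np.argsort(original)
    permutation_argsort = np.argsort(permutation)
    tbl = np.empty_like(original_argsort)
    tbl[original_argsort] = permutation_argsort
    return tbl

rng = np.random.default_rng(0)  # fixed seed: an assumption, for reproducibility only
letters = np.array([*ascii_lowercase])
p1 = rng.permutation(letters)   # 26 distinct letters, no duplicates
p2 = rng.permutation(letters)
p1l, p2l = p1.tolist(), p2.tolist()

r1 = find_pos_in_perm_1(p1l, p2l)
r2 = find_pos_in_perm_2(p1l, p2l)
r3 = find_pos_in_perm_3(p1, p2)

# All three must agree, and indexing p2 with the result must give p1.
assert r1 == r2 == r3.tolist()
assert np.array_equal(p2[r3], p1)
```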

1 Answer:

Answer 0: (score: 2)

You can do it in O(N log N) using argsort:

>>> import numpy as np
>>> from string import ascii_lowercase
>>> 
>>> letters = np.array([*ascii_lowercase])
>>> p1, p2 = map(np.random.permutation, 2*(letters,))
>>> 
>>> o1, o2 = map(np.argsort, (p1, p2))
>>> o12, o21 = map(np.empty_like, (o1, o2))
>>> o12[o1], o21[o2] = o2, o1
>>> 
>>> print(np.all(p1[o21] == p2))
True
>>> print(np.all(p2[o12] == p1))
True
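Why the assignment o12[o1] = o2 works: sorting p1 and sorting p2 yield the same sequence, so for every rank k, p1[o1[k]] and p2[o2[k]] are the same element. The element at position o1[k] of p1 therefore sits at position o2[k] of p2. A small sketch on the four-letter example from the question, with variable names following the session above:

```python
import numpy as np

p1 = np.array(['a', 'e', 'c', 'f'])
p2 = np.array(['e', 'f', 'a', 'c'])

# o1[k] / o2[k] is the position of the k-th smallest element in p1 / p2.
o1 = np.argsort(p1)
o2 = np.argsort(p2)

# Both arrays sort to the same sequence, so equal ranks hold equal elements.
assert np.array_equal(p1[o1], p2[o2])

# Scatter: position o1[k] in p1 corresponds to position o2[k] in p2.
o12 = np.empty_like(o1)
o12[o1] = o2

assert np.array_equal(p2[o12], p1)
```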

An O(N) solution using a Python dict:

>>> import operator as op
>>>    
>>> l1, l2 = map(op.methodcaller('tolist'), (p1, p2))
>>> 
>>> s12 = op.itemgetter(*l1)({k: v for v, k in enumerate(l2)})
>>> print(np.all(s12 == o12))
True
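One caveat with op.itemgetter(*l1): given a single key it returns the bare value rather than a tuple, and given no keys it raises TypeError. If those edge cases matter, a plain list comprehension over the same lookup table is one possible alternative. The sketch below is illustrative only, and the function name positions_via_dict is hypothetical.

```python
import operator as op

def positions_via_dict(l1, l2):
    # Hypothetical helper: map each value of l2 to its position, then look up l1.
    lookup = {value: index for index, value in enumerate(l2)}
    return [lookup[value] for value in l1]

l1 = ['a', 'e', 'c', 'f']
l2 = ['e', 'f', 'a', 'c']
table = {k: v for v, k in enumerate(l2)}

# Same result as the itemgetter version for ordinary input.
assert positions_via_dict(l1, l2) == list(op.itemgetter(*l1)(table))

# With one key, itemgetter returns a bare value, not a 1-tuple.
single = op.itemgetter('a')(table)
assert not isinstance(single, tuple)

# The list comprehension always returns a list, including for empty input.
assert positions_via_dict(['a'], l2) == [2]
assert positions_via_dict([], l2) == []
```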

Some timings:

26 elements
argsort      0.004 ms
dict         0.003 ms
676 elements
argsort      0.096 ms
dict         0.075 ms
17576 elements
argsort      4.366 ms
dict         2.915 ms
456976 elements
argsort    191.376 ms
dict       230.459 ms

Benchmark code:

import numpy as np
from string import ascii_lowercase
import operator as op
from timeit import timeit

L1 = np.array([*ascii_lowercase], object)
L2 = np.add.outer(L1, L1).ravel()
L3 = np.add.outer(L2, L1).ravel()
L4 = np.add.outer(L2, L2).ravel()
letters = (*map(op.methodcaller('astype', str), (L1, L2, L3, L4)),)

def use_argsort(p1, p2):
    o1, o2 = map(np.argsort, (p1, p2))
    o12 = np.empty_like(o1)
    o12[o1] = o2
    return o12

def use_dict(l1, l2):
    return op.itemgetter(*l1)({k: v for v, k in enumerate(l2)})

for L, N in zip(letters, (1000, 1000, 200, 4)):
    print(f'{len(L)} elements')
    p1, p2 = map(np.random.permutation, (L, L))
    l1, l2 = map(op.methodcaller('tolist'), (p1, p2))
    T = (timeit(lambda: f(i1, i2), number=N)*1000/N for f, i1, i2 in (
        (use_argsort, p1, p2), (use_dict, l1, l2)))
    for m, t in zip(('argsort', 'dict   '), T):
        print(m, f'{t:10.3f} ms')