从另一个数组获取匹配索引

时间:2020-06-30 12:53:03

标签: python numpy

给出两个np.arrays;

a = np.array([1, 6, 5, 3, 8, 345, 34, 6, 2, 867])
b = np.array([867, 8, 34, 75])

我想得到一个尺寸与b相同的np.array,其中每个值都是b中的值出现在a中的索引,如果b中的值不存在于a中则为np.nan。

结果应该是

[9, 4, 6, nan]

a和b始终具有相同的尺寸数,但是尺寸的大小可能不同。

类似的东西

(伪代码)

c = np.where(b in a)

但适用于数组(“ in”无效)

我更喜欢“单线”或至少是完全在阵列级别上并且不需要循环的解决方案。

谢谢!

2 个答案:

答案 0 :(得分:6)

方法1

这里是np.searchsorted-

def find_indices(a,b,invalid_specifier=-1):
    # Search for matching indices for each b in sorted version of a. 
    # We use sorter arg to account for the case when a might not be sorted 
    # using argsort on a
    sidx = a.argsort()
    idx = np.searchsorted(a,b,sorter=sidx)

    # Remove out of bounds indices as they wont be matches
    idx[idx==len(a)] = 0

    # Get traced back indices corresponding to original version of a
    idx0 = sidx[idx]
    
    # Mask out invalid ones with invalid_specifier and return
    return np.where(a[idx0]==b, idx0, invalid_specifier)

给定样本的输出-

In [41]: find_indices(a, b, invalid_specifier=np.nan)
Out[41]: array([ 9.,  4.,  6., nan])

方法2

另一个基于lookup的正数-

def find_indices_lookup(a,b,invalid_specifier=-1):
    # Setup array where we will assign ranged numbers
    N = max(a.max(), b.max())+1
    lookup = np.full(N, invalid_specifier)
    
    # We index into lookup with b to trace back the positions. Non matching ones
    # would have invalid_specifier values as wount had been indexed by ranged ones
    lookup[a] = np.arange(len(a))
    indices  = lookup[b]
    return indices

基准化

效率未在问题中提到,但没有循环要求。使用尝试重新显示给定样本设置的设置进行测试,但按1000x进行放大:

In [98]: a = np.random.permutation(np.unique(np.random.randint(0,20000,10000)))

In [99]: b = np.random.permutation(np.unique(np.random.randint(0,20000,4000)))

# Solutions from this post
In [100]: %timeit find_indices(a,b,invalid_specifier=np.nan)
     ...: %timeit find_indices_lookup(a,b,invalid_specifier=np.nan)
1.35 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
220 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# @Quang Hoang-soln2
In [101]: %%timeit
     ...: commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)
     ...: orders = np.argsort(idx_b)
     ...: output = np.full(b.shape, np.nan)
     ...: output[orders] = idx_a[orders]
1.63 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# @Quang Hoang-soln1
In [102]: %%timeit
     ...: s = b == a[:,None]
     ...: np.where(s.any(0), np.argmax(s,0), np.nan)
137 ms ± 9.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

答案 1 :(得分:2)

您可以进行广播:

s = b == a[:,None]
np.where(s.any(0), np.argmax(s,0), np.nan)

输出:

array([ 9.,  4.,  6., nan])

更新另一种使用intersect1d的解决方案:

commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)

orders = np.argsort(idx_b)

output = np.full(b.shape, np.nan)
output[orders] = idx_a[orders]
相关问题