What I want is something like a SQL WHERE expression over two arrays in NumPy. Suppose I have two arrays like this:
import numpy as np
dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float_)])
a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
                       [1, 5, 5, 4, 2, 2],
                       [2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
b = np.rec.fromarrays([[1, 4, 7, 9, 9],
                       [7, 5, 4, 2, 2],
                       [-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
I want to return indices into the original arrays such that every possible matching pair is covered. In addition, I want to take advantage of the fact that both arrays are sorted, so a worst-case O(mn) algorithm is not needed. In this case, since (4, 5, -4.5) matches but appears twice in the first array, it will appear twice in the resulting indices, and since (9, 2, 24.3) appears twice in both arrays, it will appear four times in total. Since (3, 1, 2.0) does not appear in the second array it is skipped, as is (1, 7, -3.5) from the second array. The function should work for any dtype.
In this case, the result would be like this:
a_idx, b_idx = match_arrays(a, b)
a_idx = np.array([1, 2, 3, 4, 4, 5, 5])
b_idx = np.array([1, 1, 2, 3, 4, 3, 4])
Another example with the same output:
dt2 = np.dtype([('f1', np.uint8), ('f2', dt)])
a2 = np.rec.fromarrays([[3, 4, 4, 7, 9, 9], a], dtype=dt2)
b2 = np.rec.fromarrays([[1, 4, 7, 9, 9], b], dtype=dt2)
I have a pure Python implementation, but it is slow as molasses for my use case. I am hoping for something more vectorized. Here is what I have so far:
def match_arrays(a, b):
    len_a = len(a)
    len_b = len(b)
    a_idx = []
    b_idx = []
    i, j = 0, 0
    first_matched_j = 0
    while i < len_a and j < len_b:
        matched = False
        j = first_matched_j
        while j < len_b and a[i] == b[j]:
            a_idx.append(i)
            b_idx.append(j)
            if not matched:
                matched = True
                first_matched_j = j
            j += 1
        else:
            i += 1
            j = first_matched_j
        while i < len_a and j < len_b and a[i] > b[j]:
            j += 1
            first_matched_j = j
        while i < len_a and j < len_b and a[i] < b[j]:
            i += 1
    return np.array(a_idx), np.array(b_idx)
Edit: As pointed out in Divakar's answer, I can use a_idx, b_idx = np.where(np.equal.outer(a, b)). However, that appears to be the worst-case O(mn) solution I wanted to avoid by pre-sorting the arrays. In particular, if there are no duplicates, O(m + n) would be great.
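As an aside, for the no-duplicate case NumPy's np.intersect1d with return_indices=True (available since NumPy 1.15) already returns matching indices directly, although it sorts internally rather than exploiting the presorted inputs. A minimal sketch on made-up 1-D arrays with unique values, not the recarrays above:

import numpy as np

# Hypothetical unique, sorted 1-D data; intersect1d returns the common
# values plus first-occurrence indices into each input.
x = np.array([1, 3, 5, 7, 9])
y = np.array([2, 3, 4, 7, 8])
common, x_idx, y_idx = np.intersect1d(x, y, assume_unique=True, return_indices=True)
print(common, x_idx, y_idx)  # [3 7] [1 3] [1 3]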
Edit 2: Paul Panzer's answer is not O(m + n) when using only NumPy, but it is usually faster. He also provided an O(m + n) answer, so I accepted that one. I will post a timeit performance comparison soon.
Edit 3: Here are the performance results, as promised:
╔════════════════╦═══════════════════╦═══════════════════╦═══════════════════╦══════════════════╦═══════════════════╗
║ User ║ Version ║ n = 10 ** 2 ║ n = 10 ** 4 ║ n = 10 ** 6 ║ n = 10 ** 8 ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Paul Panzer ║ USE_HEAPQ = False ║ 115 µs ± 385 ns ║ 793 µs ± 8.43 µs ║ 105 ms ± 1.57 ms ║ 18.2 s ± 116 ms ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ USE_HEAPQ = True ║ 189 µs ± 3.6 µs ║ 6.38 ms ± 28.8 µs ║ 650 ms ± 2.49 ms ║ 1min 11s ± 420 ms ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ SigmaPiEpsilon ║ Generator ║ 936 µs ± 1.52 µs ║ 9.17 s ± 57 ms ║ N/A ║ N/A ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ for loop ║ 144 µs ± 526 ns ║ 15.6 ms ± 18.6 µs ║ 1.74 s ± 33.9 ms ║ N/A ║
╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ Divakar ║ np.where ║ 39.1 µs ± 281 ns ║ 302 ms ± 4.49 ms ║ Out of memory ║ N/A ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ recarrays 1 ║ 69.9 µs ± 491 ns ║ 1.6 ms ± 24.2 µs ║ 230 ms ± 3.52 ms ║ 41.5 s ± 543 ms ║
║ ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
║ ║ recarrays 2 ║ 82.6 µs ± 1.01 µs ║ 1.4 ms ± 4.51 µs ║ 212 ms ± 2.59 ms ║ 36.7 s ± 900 ms ║
╚════════════════╩═══════════════════╩═══════════════════╩═══════════════════╩══════════════════╩═══════════════════╝
It looks like Paul Panzer's answer with USE_HEAPQ = False wins. I expected USE_HEAPQ = True to win for large inputs since it is O(m + n), but it turns out that is not the case. Another note is that the USE_HEAPQ = False version uses less memory, peaking at 5.79 GB versus 10.18 GB for USE_HEAPQ = True at n = 10 ** 8. Keep in mind that this is process memory, which includes the console input and other things. Divakar's recarrays answer 1 used 8.42 GB of memory, and recarrays answer 2 used 10.61 GB.
Answer 0 (score: 2)
Approach #1: Broadcasting-based method
Leverage vectorized broadcasting with an outer equality comparison between the two arrays, then get the row and column indices of the matches, which correspond to the matching index pairs into the two arrays -
a_idx, b_idx = np.where(a[:,None]==b)
a_idx, b_idx = np.where(np.equal.outer(a,b))
We could also use np.nonzero in place of np.where.
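For instance, a minimal sketch on plain 1-D arrays (using the f1 values from the question as hypothetical standalone data):

import numpy as np

x = np.array([3, 4, 4, 7, 9, 9])
y = np.array([1, 4, 7, 9, 9])

# Outer equality comparison via broadcasting, then row/column indices
xi, yi = np.nonzero(x[:, None] == y)
print(xi)  # [1 2 3 4 4 5 5]
print(yi)  # [1 1 2 3 4 3 4]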
Approach #2: Specific-case solution
With no duplicates and sorted input arrays, we can use np.searchsorted, like so -
idx0 = np.searchsorted(a,b)
idx1 = np.searchsorted(b,a)
idx0[idx0==len(a)] = 0
idx1[idx1==len(b)] = 0
a_idx = idx0[a[idx0] == b]
b_idx = idx1[b[idx1] == a]
A slight modification of this could be more efficient -
idx0 = np.searchsorted(a,b)
idx0[idx0==len(a)] = 0
a_idx = idx0[a[idx0] == b]
b_idx = np.searchsorted(b,a[a_idx])
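A quick usage sketch of this searchsorted approach, on hypothetical sorted 1-D arrays without duplicates (the data below is made up, not from the question):

import numpy as np

# Sorted inputs with no duplicates, as Approach #2 assumes
a1 = np.array([1, 3, 5, 7, 9])
b1 = np.array([2, 3, 4, 7, 8])

idx0 = np.searchsorted(a1, b1)
idx0[idx0 == len(a1)] = 0            # guard out-of-range insertion points
a_idx = idx0[a1[idx0] == b1]         # positions in a1 whose value also occurs in b1
b_idx = np.searchsorted(b1, a1[a_idx])
print(a_idx, b_idx)                  # [1 3] [1 3]  (values 3 and 7 match)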
Approach #3: Generic case
Here's a solution for the generic case (duplicates allowed) -
def findwhere(a, b):
    c = np.bincount(b, minlength=a.max()+1)[a]
    a_idx1 = np.repeat(np.flatnonzero(c),c[c!=0])
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1
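A quick usage sketch of findwhere (it assumes sorted, non-negative integer inputs because of np.bincount); on the question's f1 values, treated as plain 1-D arrays, it reproduces the expected index pairs:

import numpy as np

x = np.array([3, 4, 4, 7, 9, 9])
y = np.array([1, 4, 7, 9, 9])
xi, yi = findwhere(x, y)
print(xi)  # [1 2 3 4 4 5 5]
print(yi)  # [1 1 2 3 4 3 4]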
Timings, with inputs set up using mock_data from @Paul Panzer's soln:
In [295]: a, b = mock_data(1000000)
# @Paul Panzer's soln
In [296]: %timeit sqlwhere(a, b) # USE_HEAPQ = False
10 loops, best of 3: 118 ms per loop
# Approach #3 from this post
In [297]: %timeit findwhere(a,b)
10 loops, best of 3: 61.7 ms per loop
Utility to convert recarrays (uint8 data) into 1D arrays:
def convert_recarrays_to_1Darrs(a, b):
    a2D = a.view('u1').reshape(-1,2)
    b2D = b.view('u1').reshape(-1,2)
    s = max(a2D[:,0].max(), b2D[:,0].max())+1
    a1D = s*a2D[:,1] + a2D[:,0]
    b1D = s*b2D[:,1] + b2D[:,0]
    return a1D, b1D
Sample run -
In [90]: dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
...: a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
...: [1, 5, 5, 4, 2, 2]], dtype=dt)
...: b = np.rec.fromarrays([[1, 4, 7, 9, 9],
...: [7, 5, 4, 2, 2]], dtype=dt)
In [91]: convert_recarrays_to_1Darrs(a, b)
Out[91]:
(array([13, 54, 54, 47, 29, 29], dtype=uint8),
array([71, 54, 47, 29, 29], dtype=uint8))
Covering rec-arrays
Version #1:
def findwhere_generic_v1(a, b):
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]
    a_starts = np.searchsorted(a,b_starts)
    a_starts[a_starts==len(a)] = 0
    valid_mask = (b_starts == a[a_starts])
    count_valid = count[valid_mask]
    idx2m0 = np.searchsorted(a,b_starts[valid_mask],'right')
    idx1m0 = a_starts[valid_mask]
    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1
    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1
Version #2:
def findwhere_generic_v2(a, b):
    cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    count = np.diff(cidx)
    b_starts = b[cidx[:-1]]
    idxx = np.flatnonzero(np.r_[True,a[1:] != a[:-1],True])
    av = a[idxx[:-1]]
    idxxs = np.searchsorted(av,b_starts)
    idxxs[idxxs==len(av)] = 0
    valid_mask0 = av[idxxs] == b_starts
    starts = idxx[idxxs]
    stops = idxx[idxxs+1]
    idx1m0 = starts[valid_mask0]
    idx2m0 = stops[valid_mask0]
    count_valid = count[valid_mask0]
    id_arr = np.zeros(len(a)+1, dtype=int)
    id_arr[idx2m0] -= 1
    id_arr[idx1m0] += 1
    n = idx2m0 - idx1m0
    r1 = np.flatnonzero(id_arr.cumsum()!=0)
    r2 = np.repeat(count_valid,n)
    a_idx1 = np.repeat(r1, r2)
    b_idx1 = np.searchsorted(b,a[a_idx1])
    m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
    idx11 = np.flatnonzero(m1[1:] != m1[:-1])
    id_arr = m1.astype(int)
    id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
    b_idx1 += id_arr.cumsum()[:-1]
    return a_idx1, b_idx1
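A hedged sanity-check sketch (it assumes the recarrays a and b and the pure-Python match_arrays reference from the question are in scope): both rec-array versions should agree with that reference on the question's example data.

# Sketch only: compare both generic rec-array versions against the
# question's pure-Python match_arrays reference.
ref_a, ref_b = match_arrays(a, b)
for f in (findwhere_generic_v1, findwhere_generic_v2):
    ai, bi = f(a, b)
    print(f.__name__, np.array_equal(ai, ref_a), np.array_equal(bi, ref_b))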
Answer 1 (score: 2)
Here is an O(n)-ish solution (it obviously cannot be O(n) if runs of duplicates are long). In practice, depending on the input lengths, it may pay to sacrifice strict O(n) and replace heapq.merge with a stable np.argsort. Currently, N = 10^6 takes about one second.
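For reference, heapq.merge consumes two already-sorted sequences in a single O(m + n) pass, whereas a stable argsort of their concatenation costs O((m + n) log(m + n)); a tiny illustration (a sketch, not part of the solution code):

from heapq import merge

# Both inputs are already sorted; merge lazily yields a sorted stream.
print(list(merge([1, 4, 7], [3, 4, 9])))  # [1, 3, 4, 4, 7, 9]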
Code:
import numpy as np
USE_HEAPQ = True
def sqlwhere(a, b):
    asw = np.r_[0, 1 + np.flatnonzero(a[:-1]!=a[1:]), len(a)]
    bsw = np.r_[0, 1 + np.flatnonzero(b[:-1]!=b[1:]), len(b)]
    al, bl = np.diff(asw), np.diff(bsw)
    na, nb = len(al), len(bl)
    abunq = np.r_[a[asw[:-1]], b[bsw[:-1]]]
    if USE_HEAPQ:
        from heapq import merge
        m = np.fromiter(merge(range(na), range(na, na+nb), key=abunq.__getitem__), int, na+nb)
    else:
        m = np.argsort(abunq, kind='mergesort')
    mv = abunq[m]
    midx = np.flatnonzero(mv[:-1]==mv[1:])
    ai, bi = m[midx], m[midx+1] - na
    aic = np.r_[0, np.cumsum(al[ai])]
    a_idx = np.ones((aic[-1],), dtype=int)
    a_idx[aic[:-1]] = asw[ai]
    a_idx[aic[1:-1]] -= asw[ai[:-1]] + al[ai[:-1]] - 1
    a_idx = np.repeat(np.cumsum(a_idx), np.repeat(bl[bi], al[ai]))
    bi = np.repeat(bi, al[ai])
    bic = np.r_[0, np.cumsum(bl[bi])]
    b_idx = np.ones((bic[-1],), dtype=int)
    b_idx[bic[:-1]] = bsw[bi]
    b_idx[bic[1:-1]] -= bsw[bi[:-1]] + bl[bi[:-1]] - 1
    b_idx = np.cumsum(b_idx)
    return a_idx, b_idx

def f_D(a, b):
    return np.where(np.equal.outer(a,b))

def mock_data(n):
    return np.cumsum(np.random.randint(0, 3, (2, n)), axis=1)
a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)
# check correct
a, b = mock_data(1000)
ai0, bi0 = f_D(a, b)
ai1, bi1 = sqlwhere(a, b)
print(np.all(ai0 == ai1), np.all(bi0 == bi1))
# check fast
a, b = mock_data(1000000)
sqlwhere(a, b)
Answer 2 (score: 1)
An alternative pure-Python implementation with a generator and list comprehension. It may be more memory-efficient than your code, but it will probably be slower than the numpy versions. For sorted arrays it will be faster.
def pywheregen(a, b):
    l = ((ia,ib) for ia,j in enumerate(a) for ib,k in enumerate(b) if j == k)
    a_idx,b_idx = zip(*l)
    return a_idx,b_idx
Here is an alternative version that uses a simple Python for loop and takes into account that the arrays are sorted, so that it only checks the pairs it needs to.
def pywhere(a, b):
    l = []
    a.sort()
    b.sort()
    match = 0
    for ia,j in enumerate(a):
        ib = match
        while ib < len(b) and j >= b[ib]:
            if j == b[ib]:
                l.append(((ia,ib)))
            if b[match] < b[ib]:
                match = ib
            ib += 1
    a_ind,b_ind = zip(*l)
    return a_ind, b_ind
I compared the timings using @Paul Panzer's mock_data() function, against findwhere() and the np.outer-based f_D() from @Divakar. findwhere() still performs best, but pywhere() is not that bad considering it is pure Python. pywheregen() does badly and, surprisingly, f_D() takes longer. Both of them fail for N = 10^6. I could not run sqlwhere due to an unrelated error in the heapq module.
In [2]: a, b = mock_data(10000)
In [10]: %timeit -n 10 findwhere(a,b)
10 loops, best of 3: 1.62 ms per loop
In [11]: %timeit -n 10 pywhere(a,b)
10 loops, best of 3: 20.6 ms per loop
In [12]: %timeit pywheregen(a,b)
1 loop, best of 3: 12.7 s per loop
In [13]: %timeit -n 10 f_D(a,b)
10 loops, best of 3: 476 ms per loop
In [14]: a, b = mock_data(1000000)
In [15]: %timeit -n 10 findwhere(a,b)
10 loops, best of 3: 109 ms per loop
In [16]: %timeit -n 10 pywhere(a,b)
10 loops, best of 3: 2.51 s per loop