我有两个numpy数组,A和B,表示2D平面中点的坐标。我们假设A是10000乘2,B是20000乘2。两者都有float64
dtype。
我想找出第一个数组中的哪个点A在第二个数组中(B)。使用for
循环执行此操作会非常慢。我想出了以下广播方案来进行比较(暂时忽略浮点平等与亲密度问题):
x_bool_array = A[:,0][numpy.newaxis,...] == B[:,0][...,numpy.newaxis]
y_bool_array = A[:,1][numpy.newaxis,...] == B[:,1][...,numpy.newaxis]
bool_array = numpy.logical_and(x_bool_array, y_bool_array)
indices = numpy.where(bool_array)
然而,这将导致非常大的20000×10000个布尔数组,这些数组主要是稀疏的,即True
s的数量远小于False
的数量。
我想知道是否有办法让他们通过某些开关或财产稀疏?或者,如果有更好的方法来做到这一点,那就是快速且不会消耗大量内存? (分段做可能是另一种选择,但我想我也在寻找优雅,除了速度和低内存)。
修改:回应@Tai的评论澄清,让我们举一个小例子:
A = numpy.array([[0.1, 0.2], [0.34, 0.44], [0.5, 0.6]])
B = numpy.array([[0.05, 0.05], [0.1, 0.2], [0.7, 0.8], [0.5, 0.6]])
换句话说,A是3个2D点(3乘2)的数组,B是4个2D点(4乘2)的数组。
我们可以看到B[1,:]
与A[0,:]
相同,B[3,:]
与A[2,:]
相同。所以我们有两场比赛。最终结果indices
如下:
(array([1, 3]), array([0, 2]))
编辑2 :之前我说过分段是一种选择。我尝试过它并没有更好。本质上,我将两个数组中的一个分成100个块,在每个块上对整个第二个数组进行逻辑比较,并在for
循环中合并结果。不幸的是,没有办法让解释器知道它可以使用以前的内存(即,你不能显式地控制垃圾收集器,或者至少它不会是非常惯用的python / numpy),并且分配器不断分配新的内存对于每个新的块。
答案 0 :(得分:0)
如果您不介意,大熊猫将是一种解决方法。
import pandas as pd
import numpy as np
A = np.array([[0.1, 0.2], [0.34, 0.44], [0.5, 0.6]])
B = np.array([[0.05, 0.05], [0.1, 0.2], [0.7, 0.8], [0.5, 0.6]])
dfA = pd.DataFrame(A, columns=["v1", "v2"]).reset_index()
dfB = pd.DataFrame(B, columns=["v1", "v2"]).reset_index()
common_vals = pd.merge(dfA, dfB, how='inner', on=['v1','v2'])
index_x v1 v2 index_y
0 0 0.1 0.2 1
1 2 0.5 0.6 3
然后通过传递您需要的列名列表,index_x
,选择index_y
和["index_x", "index_y"]
两列。
common_vals[["index_x", "index_y"]].as_matrix()
Out: array([[0, 1],
[2, 3]])
答案 1 :(得分:0)
从根本上说,这是一个最近邻搜索,您正在寻找距离为零的邻居。您可以使用适当的数据结构非常有效地完成此操作;在这里,KD-Tree是最好的选择。
以下是使用您提供的阵列的快速示例:
from scipy.spatial import cKDTree
dist, ind = cKDTree(B).query(A, 1)
results = (ind[dist == 0], np.where(dist == 0)[0])
results
# (array([1, 3]), array([0, 2]))
这种方法对于非常大的数组应该可以很好地扩展,因为它避免了直接方法所需的所有N x M
比较。对于您建议的大型阵列的大小,这将在不到20毫秒内完成:
A = np.random.randint(0, 1000, (10000, 2))
B = np.random.randint(0, 1000, (20000, 2))
%%timeit
dist, ind = cKDTree(B).query(A, 1)
results = ind[dist == 0], np.where(dist == 0)[0]
# 16.9 ms ± 530 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)