Python:比较两个数组的元素

时间:2017-04-04 11:51:10

标签: python arrays performance numpy

我想比较两个numpy数组的元素,如果坐标之间的eucledean距离小于1且时间相同,则删除其中一个数组的元素。 data_CD4和data_CD8是数组。数组的元素是带有3D坐标的列表,时间是第4个元素(numpy.array([[x,y,z,time],[x,y,z,time] .....])。截止,这里是1。

for i in data_CD8:
        for m in data_CD4:
            if distance.euclidean(tuple(i[:3]),tuple(m[:3])) < co and i[3]==m[3] :
                data_CD8=np.delete(data_CD8, i, 0)

有更快的方法吗?第一个数组有5000个元素,第二个数组是2000,所以花了太多时间。

3 个答案:

答案 0 :(得分:2)

这应该是一个矢量化方法。

mask1 = np.sum((data_CD4[:, None, :3] - data_CD8[None, :, :3])**2, axis = -1) < co**2
mask2 = data_CD4[:, None, 3] == data_CD8[None, :, 3]
mask3 = np.any(np.logical_and(mask1, mask2), axis = 0)
data_CD8 = data_CD8[~mask3]

mask1应加快距离计算,因为它不需要平方根调用。 mask1mask2是2-D数组,我们通过np.any将其挤压到1d。在最后执行所有删除操作可以防止读/写堆。

速度测试:

a = np.random.randint(0, 10, (100, 3))

b = np.random.randint(0, 10, (100, 3))

%timeit cdist(a,b) < 5  #Divakar's answer
10000 loops, best of 3: 133 µs per loop

%timeit np.sum((a[None, :, :] - b[:, None, :]) ** 2, axis = -1) < 25  # My answer
1000 loops, best of 3: 418 µs per loop

即使添加了不必要的平方根,C编译的代码也会获胜。

答案 1 :(得分:2)

这是使用Scipy's cdist -

的矢量化方法
from scipy.spatial import distance

# Get eucliden distances between first three cols off data_CD8 and data_CD4
dists = distance.cdist(data_CD8[:,:3], data_CD4[:,:3])

# Get mask of those distances that are within co distance. This sets up the 
# first condition requirement as posted in the loopy version of original code.
mask1 = dists < co

# Take the third column off the two input arrays that represent the time values.
# Get the equality between all time values off data_CD8 against all time values
# off data_CD4. This sets up the second conditional requirement.
# We are adding a new axis with None, so that NumPY broadcasting
# would let us do these comparisons in a vectorized manner.
mask2 = data_CD8[:,3,None] == data_CD4[:,3]

# Combine those two masks and look for any match correponding to any 
# element off data_CD4. Since the masks are setup such that second axis
# represents data_CD4, we need numpy.any along axis=1 on the combined mask.
# A final inversion of mask is needed as we are deleting the ones that 
# satisfy these requirements.
mask3 = ~((mask1 & mask2).any(1))

# Finally, using boolean indexing to select the valid rows off data_CD8
out = data_CD8[mask3]

答案 2 :(得分:0)

如果您必须将data_CD4中的所有项目与data_CD8中的项目进行比较 从data_CD8中删除数据时,在每次迭代中使第二个迭代变小可能会更好,这当然取决于您最常见的 案件。

for m in data_CD4:
    for i in data_CD8:
        if distance.euclidean(tuple(i[3:]),tuple(m[3:])) < co and i[3]==m[3] :
            data_CD8 = np.delete(data_CD8, i, 0)

基于大O符号 - 因为这是O(n^2) - 我看不到更快 溶液