删除NumPy数组中包含重复项的行

时间:2011-09-15 23:01:46

标签: python performance numpy vectorization

我有一个(N,3)数组的numpy值:

>>> vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
>>> vals
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 7],
       [0, 4, 5],
       [2, 2, 1],
       [0, 0, 0],
       [5, 4, 3]])

我想从数组中删除具有重复值的行。例如,上述数组的结果应为:

>>> duplicates_removed
array([[1, 2, 3],
       [4, 5, 6],
       [0, 4, 5],
       [5, 4, 3]])

我不确定如何在没有循环的情况下有效地使用numpy(数组可能非常大)。谁知道我怎么能这样做?

5 个答案:

答案 0 :(得分:10)

这是一个选项:

import numpy
vals = numpy.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])
a = (vals[:,0] == vals[:,1]) | (vals[:,1] == vals[:,2]) | (vals[:,0] == vals[:,2])
vals = numpy.delete(vals, numpy.where(a), axis=0)

答案 1 :(得分:3)

这是处理通用列数的方法,仍然是矢量化方法 -

def rows_uniq_elems(a):
    a_sorted = np.sort(a,axis=-1)
    return a[(a_sorted[...,1:] != a_sorted[...,:-1]).all(-1)]

步骤:

  • 沿每行排序。

  • 查找每行中连续元素之间的差异。因此,具有至少一个零微分的任何行表示重复元素。我们将使用它来获取有效行的掩码。因此,最后一步是使用掩码简单地从输入数组中选择有效行。

示例运行 -

In [49]: a
Out[49]: 
array([[1, 2, 3, 7],
       [4, 5, 6, 7],
       [7, 8, 7, 8],
       [0, 4, 5, 6],
       [2, 2, 1, 1],
       [0, 0, 0, 3],
       [5, 4, 3, 2]])

In [50]: rows_uniq_elems(a)
Out[50]: 
array([[1, 2, 3, 7],
       [4, 5, 6, 7],
       [0, 4, 5, 6],
       [5, 4, 3, 2]])

答案 2 :(得分:2)

numpy.array([v for v in vals if len(set(v)) == len(v)])

请注意,这仍然在幕后循环。你无法避免这种情况。但它应该可以正常工作,即使是数百万行。

答案 3 :(得分:2)

六年过去了,但这个问题对我有帮助,所以我对Divakar,Benjamin,Marcelo Cantos和Curtis Patrick给出的答案进行了速度比较。

import numpy as np
vals = np.array([[1,2,3],[4,5,6],[7,8,7],[0,4,5],[2,2,1],[0,0,0],[5,4,3]])

def rows_uniq_elems1(a):
    idx = a.argsort(1)
    a_sorted = a[np.arange(idx.shape[0])[:,None], idx]
    return a[(a_sorted[:,1:] != a_sorted[:,:-1]).all(-1)]

def rows_uniq_elems2(a):
    a = (a[:,0] == a[:,1]) | (a[:,1] == a[:,2]) | (a[:,0] == a[:,2])
    return np.delete(a, np.where(a), axis=0)

def rows_uniq_elems3(a):
    return np.array([v for v in a if len(set(v)) == len(v)])

def rows_uniq_elems4(a):
    return np.array([v for v in a if len(np.unique(v)) == len(v)])

结果:

%timeit rows_uniq_elems1(vals)
10000 loops, best of 3: 67.9 µs per loop

%timeit rows_uniq_elems2(vals)
10000 loops, best of 3: 156 µs per loop

%timeit rows_uniq_elems3(vals)
1000 loops, best of 3: 59.5 µs per loop

%timeit rows_uniq_elems(vals)
10000 loops, best of 3: 268 µs per loop

似乎使用set节拍numpy.unique。在我的情况下,我需要在更大的阵列上执行此操作:

bigvals = np.random.randint(0,10,3000).reshape([3,1000])

%timeit rows_uniq_elems1(bigvals)
10000 loops, best of 3: 276 µs per loop

%timeit rows_uniq_elems2(bigvals)
10000 loops, best of 3: 192 µs per loop

%timeit rows_uniq_elems3(bigvals)
10000 loops, best of 3: 6.5 ms per loop

%timeit rows_uniq_elems4(bigvals)
10000 loops, best of 3: 35.7 ms per loop

没有列表推导的方法要快得多。但是,行数是硬编码的,很难扩展到三列以上,所以在我的情况下,至少列表理解是最好的答案。

已编辑,因为我混淆了bigvals

中的行和列

答案 4 :(得分:1)

与Marcelo相同,但我认为使用numpy.unique()代替set()可能会得到您正在拍摄的内容。

numpy.array([v for v in vals if len(numpy.unique(v)) == len(v)])