Filter a 2D numpy array if its elements occur more than once

Time: 2016-08-03 20:12:43

Tags: numpy multidimensional-array

I want to delete the rows of a 2D array that share an element with another row. For example:

array = [0 1]
        [2 3]
        [4 0]
        [0 4]

filtered_array = [2 3]

Edit: column position does not matter.

2 answers:

Answer 0 (score: 2)

Here's a vectorized approach using NumPy broadcasting:

import numpy as np

def filter_rows(arr):
    # Detect matches along same columns for both cols
    samecol_mask1 = arr[:,None,0] == arr[:,0]
    samecol_mask2 = arr[:,None,1] == arr[:,1]
    samecol_mask = np.triu(samecol_mask1 | samecol_mask2,1)

    # Detect matches across the two cols
    diffcol_mask = arr[:,None,0] == arr[:,1]

    # Get the combined matching mask
    mask = samecol_mask | diffcol_mask

    # Get the indices of the mask which gives us the row IDs that have matches
    # across either same or different columns. Delete those rows for output. 
    dup_rowidx = np.unique(np.argwhere(mask))
    return np.delete(arr,dup_rowidx,axis=0)
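
To make the broadcasting step in filter_rows more concrete, here is a minimal sketch of what one of the pairwise comparisons produces on the array from the question. The variable names same_first_col and pairs are illustrative only and do not appear in the answer.

```python
import numpy as np

arr = np.array([[0, 1],
                [2, 3],
                [4, 0],
                [0, 4]])

# arr[:, None, 0] has shape (4, 1) and arr[:, 0] has shape (4,), so the
# comparison broadcasts to a (4, 4) boolean matrix whose entry [i, j]
# is True when arr[i, 0] == arr[j, 0].
same_first_col = arr[:, None, 0] == arr[:, 0]

# The strict upper triangle (k=1) drops the diagonal, so a row is never
# reported as matching itself within the same column.
pairs = np.argwhere(np.triu(same_first_col, 1))
print(pairs)  # each row is a pair (i, j) of row indices sharing a column-0 value
```

Every row index that shows up in such a pair ends up in the delete list. Note that the intermediate masks are N x N, so memory grows quadratically with the number of rows.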

Sample runs to showcase various scenarios

Case #1: Multiple matches across the same and different columns

In [313]: arr
Out[313]: 
array([[0, 1],
       [2, 3],
       [4, 0],
       [0, 4]])

In [314]: filter_rows(arr)
Out[314]: array([[2, 3]])

Case #2: Matches along the same column

In [319]: arr
Out[319]: 
array([[ 0,  1],
       [ 2,  3],
       [ 8, 10],
       [ 0,  4]])

In [320]: filter_rows(arr)
Out[320]: 
array([[ 2,  3],
       [ 8, 10]])

Case #3: Matches across different columns

In [325]: arr
Out[325]: 
array([[ 0,  1],
       [ 2,  3],
       [ 8, 10],
       [ 7,  0]])

In [326]: filter_rows(arr)
Out[326]: 
array([[ 2,  3],
       [ 8, 10]])

Case #4: Match within the same row

In [331]: arr
Out[331]: 
array([[ 0,  1],
       [ 3,  3],
       [ 8, 10],
       [ 7,  0]])

In [332]: filter_rows(arr)
Out[332]: array([[ 8, 10]])

Answer 1 (score: 1)

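Just an alternative to @Divakar's impressive solution. This approach is worse in some respects (especially efficiency), but it may be easier to understand for non-numpy-gurus.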

import numpy as np

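def filter_(x):
    unique = np.unique(x) # 1
    unique_mapper = [np.where(x == z)[0] for z in unique] # 2
    # Integer-typed empty array, so the concatenated indices stay integers
    # (a plain [] would turn them into floats, which np.delete rejects).
    empty = np.array([], dtype=int)
    filtered_unique_mapper = [rows if len(rows) > 1 else empty for rows in unique_mapper] # 3
    collided = np.concatenate(filtered_unique_mapper) # 4
    to_delete = np.unique(collided) # 5
    return np.delete(x, to_delete, axis=0)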

# 1 get the globally unique values
# 2 for each unique value: get the indices of all rows containing it
#   -> more than one entry for a unique value: rows collide!
# 3 drop the entries from above where the value occurs only once
# 4 collect all rows that collided somehow
# 5 remove duplicate row indices from above
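
The same idea (find the values that occur more than once, then drop every row containing one of them) can also be sketched with np.unique(..., return_counts=True) and np.isin. This is not taken from either answer; filter_by_counts is a hypothetical name used here for illustration, and the sketch assumes a 2D input.

```python
import numpy as np

def filter_by_counts(x):
    # Hypothetical helper, not part of either answer above.
    x = np.asarray(x)
    # Count how often each value occurs anywhere in the array.
    values, counts = np.unique(x, return_counts=True)
    repeated = values[counts > 1]
    # Keep only the rows that contain none of the repeated values.
    keep = ~np.isin(x, repeated).any(axis=1)
    return x[keep]

arr = np.array([[0, 1],
                [2, 3],
                [4, 0],
                [0, 4]])
print(filter_by_counts(arr))
```

Like the answers above, this treats a value repeated within a single row (such as [3, 3]) as a collision, because the count is taken over the whole array.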