我们如何根据其项目为元组的行的掩码来过滤列?

时间:2017-05-02 06:12:40

标签: python numpy

我们假设我们有一个像这样构建的numpy数组:

import numpy as np

data = [(1, 2), (1, 3), (1, 4), (1, 5), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]
data1 = [1 for i in data]

table = np.asarray(list(zip(data, data1, data1, data1, data1))).transpose()

导致:

[[(1, 2) (1, 3) (1, 4) (1, 5) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6)]
 [1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1]]

现在还有另一个列表test = [(1, 2), (1, 3), (1, 4)] 如果第一行中的元组与测试列表中的元组不匹配,我想过滤表中的列。
我希望它的结果如下:

[[(1, 2) (1, 3) (1, 4)]
 [1 1 1]
 [1 1 1]
 [1 1 1]
 [1 1 1]]

我试过这段代码:

mask = np.in1d(table[0, :], test)
table = table[:, mask]
print(table)

但它产生了一个空列表 有什么建议? 谢谢

2 个答案:

答案 0 :(得分:1)

#use a bool array to select columns
table[:,np.array([e in test for e in table[0]])]
Out[306]: 
array([[(1, 2), (1, 3), (1, 4)],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=object)

答案 1 :(得分:0)

不太优雅,但我发现这更直观:

tb = set(table[0,:].reshape(table[0,:].size))
table[:,[i for i, t in enumerate(tb) if tuple(t) in test]]