我们假设我们有一个像这样构建的numpy数组:
import numpy as np
data = [(1, 2), (1, 3), (1, 4), (1, 5), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]
data1 = [1 for i in data]
table = np.asarray(list(zip(data, data1, data1, data1, data1))).transpose()
导致:
[[(1, 2) (1, 3) (1, 4) (1, 5) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6)]
[1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1]
[1 1 1 1 1 1 1 1 1 1]]
现在还有另一个列表test = [(1, 2), (1, 3), (1, 4)]
如果第一行中的元组与测试列表中的元组不匹配,我想过滤表中的列。
我希望它的结果如下:
[[(1, 2) (1, 3) (1, 4)]
[1 1 1]
[1 1 1]
[1 1 1]
[1 1 1]]
我试过这段代码:
mask = np.in1d(table[0, :], test)
table = table[:, mask]
print(table)
但它产生了一个空列表 有什么建议? 谢谢
答案 0 :(得分:1)
#use a bool array to select columns
table[:,np.array([e in test for e in table[0]])]
Out[306]:
array([[(1, 2), (1, 3), (1, 4)],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=object)
答案 1 :(得分:0)
不太优雅,但我发现这更直观:
tb = set(table[0,:].reshape(table[0,:].size))
table[:,[i for i, t in enumerate(tb) if tuple(t) in test]]