根据字典中的存在过滤numpy数组

时间:2018-06-05 09:58:30

标签: python numpy

我有一个numpy ndarray如下:

import numpy as np
x = np.array([[1, 2, 1], [4, 5, 7], [3, 2, 3]])

我有一个字典,它保留了一些类ID,如下所示:

k = {1: None, 2: None, 3: None}

现在,该numpy数组的最后一列包含类ID。所以我想要做的是根据字典中是否存在类ID来过滤numpy数组。因此,过滤该输入数组会将第1行和第3行设为7不在字典中。

所以我把classes列作为:

cls = x[:, -1]

现在,我不知道如何使用它来过滤x数组而不循环遍历这个并创建另一个数组。

2 个答案:

答案 0 :(得分:2)

这是使用numpy.in1d的一种方式:

keys = list(k.keys())
res = x[np.in1d(x[:, -1], keys)]

print(res)

[[1 2 1]
 [3 2 3]]

答案 1 :(得分:1)

我做的事情如下:

import numpy as np
x = np.array([[1, 2, 1], [4, 5, 7], [3, 2, 3]])
k = {1: None, 2: None, 3: None}
classes = [i for i in x[:,-1] if i in k.keys()]
classes = np.array(classes)
print(classes)

这应该返回1和3,而不是7,因为我们只查找x中的最后一行数据值。这将返回k中的值列表,如果您愿意,可以将其组成一个数组。