我有一个numpy ndarray如下:
import numpy as np
x = np.array([[1, 2, 1], [4, 5, 7], [3, 2, 3]])
我有一个字典,它保留了一些类ID,如下所示:
k = {1: None, 2: None, 3: None}
现在,该numpy数组的最后一列包含类ID。所以我想要做的是根据字典中是否存在类ID来过滤numpy数组。因此,过滤该输入数组会将第1行和第3行设为7
不在字典中。
所以我把classes列作为:
cls = x[:, -1]
现在,我不知道如何使用它来过滤x
数组而不循环遍历这个并创建另一个数组。
答案 0 :(得分:2)
这是使用numpy.in1d
的一种方式:
keys = list(k.keys())
res = x[np.in1d(x[:, -1], keys)]
print(res)
[[1 2 1]
[3 2 3]]
答案 1 :(得分:1)
我做的事情如下:
import numpy as np
x = np.array([[1, 2, 1], [4, 5, 7], [3, 2, 3]])
k = {1: None, 2: None, 3: None}
classes = [i for i in x[:,-1] if i in k.keys()]
classes = np.array(classes)
print(classes)
这应该返回1和3,而不是7,因为我们只查找x中的最后一行数据值。这将返回k中的值列表,如果您愿意,可以将其组成一个数组。