我有两个numpy数组,我试图删除第二个数组中所有值为-1的索引。
示例:
goldLabels: [12, 2, 0, 0, 0, 1, 5]
predictions: [12, 3, 0, 2, -1, -1, -1]
结果:
goldLabels: [12, 2, 0, 0]
predictions: [12, 3, 0, 2]
到目前为止,这是我的代码:
idcs = []
for idx, label in enumerate(goldLabels):
if label == -1:
idcs.append(idx)
goldLabels = np.delete(goldLabels, idcs)
predictions = np.delete(predictions, idcs)
有什么方法可以更有效地做到这一点?
答案 0 :(得分:0)
您可以使用numpy的功能直接使用掩码提取这些数字:
goldLabels = np.array([12, 2, 0, 0, 0, 1, 5])
predictions = np.array([12, 3, 0, 2, -1, -1, -1])
mask = predictions!=-1
predictions = predictions[mask]
goldLabels = goldLabels[mask]
print(goldLabels)
print(predictions)
输出:
[12 2 0 0]
[12 3 0 2]