替换受条件限制的numpy数组中的最小元素

时间:2015-10-27 14:43:29

标签: python arrays numpy

我需要将numpy数组的元素替换为另一个numpy数组的最小值,以验证一个条件。请参阅以下最小示例:

arr = np.array([0, 1, 2, 3, 4])
label = np.array([0, 0, 1, 1, 2])
cond = (label == 1)
label[cond][np.argmin(arr[cond])] = 3

我希望标签现在是

 array([0,  0,  3,  1, 2])

相反,我得到了

 array([0,  0,  1,  1, 2])

这是已知事实numpy arrays are not updated with double slicing

的结果

无论如何,我无法弄清楚如何以简单的方式重写上述代码。任何提示?

1 个答案:

答案 0 :(得分:2)

您正在使用链接索引触发NumPy's advanced indexing,因此分配不会通过。要解决这个问题,一种方法是存储与掩码相对应的索引,然后使用索引。这是实施 -

idx = np.where(cond)[0]
label[idx[arr[idx].argmin()]] = 3

示例运行 -

In [51]: arr = np.array([5, 4, 5, 8, 9])
    ...: label = np.array([0, 0, 1, 1, 2])
    ...: cond = (label == 1)
    ...: 

In [52]: idx = np.where(cond)[0]
    ...: label[idx[arr[idx].argmin()]] = 3
    ...: 

In [53]: idx
Out[53]: array([2, 3])

In [54]: label
Out[54]: array([0, 0, 3, 1, 2])