我需要将numpy数组的元素替换为另一个numpy数组的最小值,以验证一个条件。请参阅以下最小示例:
arr = np.array([0, 1, 2, 3, 4])
label = np.array([0, 0, 1, 1, 2])
cond = (label == 1)
label[cond][np.argmin(arr[cond])] = 3
我希望标签现在是
array([0, 0, 3, 1, 2])
相反,我得到了
array([0, 0, 1, 1, 2])
这是已知事实numpy arrays are not updated with double slicing。
的结果无论如何,我无法弄清楚如何以简单的方式重写上述代码。任何提示?
答案 0 :(得分:2)
您正在使用链接索引触发NumPy's advanced indexing
,因此分配不会通过。要解决这个问题,一种方法是存储与掩码相对应的索引,然后使用索引。这是实施 -
idx = np.where(cond)[0]
label[idx[arr[idx].argmin()]] = 3
示例运行 -
In [51]: arr = np.array([5, 4, 5, 8, 9])
...: label = np.array([0, 0, 1, 1, 2])
...: cond = (label == 1)
...:
In [52]: idx = np.where(cond)[0]
...: label[idx[arr[idx].argmin()]] = 3
...:
In [53]: idx
Out[53]: array([2, 3])
In [54]: label
Out[54]: array([0, 0, 3, 1, 2])