我有以下Python(NumPy),我想重构它更干净(可能更快):
temp = max(value for (x, y), value in np.ndenumerate(cm) if x * y < 100 and (x, y) != (0, 0) and not np.isnan(value))
我认为很清楚我想做什么。总而言之,我尝试根据它的值和索引的某些条件来过滤2D数组的一些元素。
感谢任何帮助。
答案 0 :(得分:5)
import numpy as np
from numpy.random import rand, randint
cm = rand(50, 100)
cm[randint(0, 50, 4000), randint(0, 100, 4000)] = np.nan
temp1 = max(value for (x, y), value in np.ndenumerate(cm) if x * y < 100 and (x, y) != (0, 0) and not np.isnan(value))
x, y = np.indices(cm.shape)
mask = (x * y < 100) & (x + y != 0) & (~np.isnan(cm))
temp2 = np.max(cm[mask])
assert temp1 == temp2
修改强>
代表max(x+y * value)
:
np.max((x + y * cm)[mask])
或
np.max(x[mask] + y[mask] * cm[mask])