如何优化numpy循环,该循环对数组的值求和,该数组由另一个数组索引,其中值等于循环索引

时间:2017-10-18 12:55:56

标签: python performance numpy for-loop

我有一段在应用程序运行期间多次调用的代码。 它需要一组代表值的数字(value_array)。 这些应该在zone中汇总,这些区域在zone_array中定义。 zone_ids表示zone_array中所有可能区域的列表。

它基本上是这样的:我有一个人口栅格地图,我想知道有多少人住在区域地图的每个区域。

代码:

values = np.zeros(len(zone_ids))
for i in zone_ids:
    values[i] = round(np.nansum(value_array[zone_array == i]), 2)
return values

罪魁祸首似乎是for循环,但我没有找到消除它的方法,并得到相同的结果。

我用bincount尝试了但是我没有成功。 使用numba jit也没有效果。

我想远离cython,因为此代码将用于没有cython支持的Qgis插件。

测试代码:

import numpy as np


def fill_values(zone_array, value_array, zone_ids):
    values = np.zeros(len(zone_ids))
    for i in zone_ids:
        values[i] = round(np.nansum(value_array[zone_array == i]), 2)
    return values


def run():
    # 300 different zones
    zone_ids = range(300)
    # zone map with 300 zones
    zone_array = (np.random.rand(2000, 2000) * 300).astype(int)
    # value map from which we want the sum of values per zone (real map can have NaN values)
    value_array = (np.random.rand(2000, 2000) * 10.)
    value_array[5, 5] = np.NAN
    fill_values(zone_array, value_array, zone_ids)


if __name__ == '__main__':
    run()

每回路1.92 s±17.5 ms(平均值±标准偏差,7次运行,每次1次循环)

按照Divakar的建议实施bincount:

每循环203 ms±15.2 ms(平均值±标准偏差,7次运行,每次循环1次)

1 个答案:

答案 0 :(得分:1)

直接使用bincount,您可以在摘要中使用NaNs。因此,您只需将NaNs替换为zeros并使用bincount即可。这应该快得多,是一个矢量化解决方案。

因此,实施将是 -

val_nonan = np.where(np.isnan(value_array), 0, value_array)
out = np.round(np.bincount(zone_array.ravel(), val_nonan.ravel()),2)