Numpy:需要最有效的方法来处理1D nararray中的选择元素,使用来自2D ndarray的映射,输出1D ndarray

时间:2017-10-13 04:34:31

标签: python arrays numpy time-complexity mean

首先,这不是一个家庭作业问题;它是与我的工作相关的真实问题的抽象。我非常感谢任何和所有的投入!

我需要运行类似下面的计算,按顺序运行数万次,并且计算时间会显着影响我的模拟总持续时间:

在这个抽象中:

  • 我有60,000个小部件和每个小部件类的价格数组, “widget_prices”。
  • 我有一个2D映射price_mapping,其中30,000行中的每一行 对应于购买一篮子这些小部件,以及每个小部件 60,000列对应于与。一致的窗口小部件类 指数widget_pricesBool的{​​{1}}值表示小部件不在购物篮中,false的值表示它们是。
  • 我想生成一个计算平均窗口小部件价格的数组 每个30,000个篮子(每排true

显示了数据结构的图示here

下面是我编写的一些代码,测试了我能想到的3种不同方法。第1个包括price_mapping和常规的python列表理解,第2个包含np.meannp.average。和元素明确的矩阵乘法,第3个包括np.tilenp.manp.tile

np.mean

这些是我得到的结果:

import numpy as np
import time

number_of_widgets = 60000
number_of_orders = 30000

widget_prices = np.random.uniform(0, 1, number_of_widgets)
price_mapping = np.random.randint(2, size=(number_of_orders, number_of_widgets), dtype=bool)

# method 1, using np.mean and a python list comprehension
start = time.time()
mean_price_array_1 = np.array([np.mean(widget_prices[price_mapping[i, :]]) for i in range(number_of_orders)])
end = time.time()
print('method 1 took ' + str(end - start) + ' seconds')

# method 2, using np.average, np.tile, and element-wise matrix multiplication
start = time.time()
mean_price_array_2 = np.average(np.tile(widget_prices, (number_of_orders, 1)) * price_mapping, weights=price_mapping,
                                axis=1)
end = time.time()
print('method 2 took ' + str(end - start) + ' seconds')

# method 3, using np.ma (masked array), np.tile, and np.mean
start = time.time()
mean_price_array_3 = np.ma.array(np.tile(widget_prices, (number_of_orders, 1)), mask=~price_mapping).mean(axis=1)
end = time.time()
print('method 3 took ' + str(end - start) + ' seconds')

第一个计算时间最快,但对我的需求来说仍然太慢。

有没有办法改进列表理解?

提前谢谢!!

-S

2 个答案:

答案 0 :(得分:1)

对于price_mapping作为布尔掩码,每次迭代选择widget_prices之后的元素,我们可以简单地将matrix-multiplicationnp.dot一起用于矢量化解决方案,并希望更快的方式,像这样 -

price_mapping.dot(widget_prices)/price_mapping.sum(1)

使用np.count_nonzero更快的方法是每行计算非零数。因此,另一种方式是 -

price_mapping.dot(widget_prices)/np.count_nonzero(price_mapping, axis=1)

答案 1 :(得分:0)

如果你想快速计算并且numpy没有帮助那么我会建议使用numba。

1)创建一个用于列表理解的循环intead的函数。 2)将@jit装饰器放在方法的开头,该方法将在多核PC上的parellal中运行。 3)来自numba import jit