我有一块代码需要尽可能地优化,因为我必须运行几千次。
它的作用是在随机浮点数的给定数组的子列表中找到最接近的浮点数,并存储存储在该数组的另一个子列表中的相应浮点数(即:具有相同的索引)。它重复该过程,直到存储的浮点数总和达到一定限度。
这是MWE
使其更清晰:
import numpy as np
# Define array with two sub-lists.
a = [np.random.uniform(0., 100., 10000), np.random.random(10000)]
# Initialize empty final list.
b = []
# Run until the condition is met.
while (sum(b) < 10000):
# Draw random [0,1) value.
u = np.random.random()
# Find closest value in sub-list a[1].
idx = np.argmin(np.abs(u - a[1]))
# Store value located in sub-list a[0].
b.append(a[0][idx])
代码相当简单,但我还没有找到加快速度的方法。我试图调整我在前一段时间提出的类似问题中给出的伟大(并且非常快)answer,但无济于事。
答案 0 :(得分:4)
好的,这是一个略微左侧的建议。据我所知,你只是试图从a[0]
中的元素中统一采样,直到你得到一个总和超过某个限制的列表。
虽然内存方面的成本会更高,但我认为您可能会发现从a[0]
首先生成大型随机样本要快得多,然后取出cumsum并找到它首先超出限制的位置。
例如:
import numpy as np
# array of reference float values, equivalent to a[0]
refs = np.random.uniform(0, 100, 10000)
def fast_samp_1(refs, lim=10000, blocksize=10000):
# sample uniformally from refs
samp = np.random.choice(refs, size=blocksize, replace=True)
samp_sum = np.cumsum(samp)
# find where the cumsum first exceeds your limit
last = np.searchsorted(samp_sum, lim, side='right')
return samp[:last + 1]
# # if it's ok to be just under lim rather than just over then this might
# # be quicker
# return samp[samp_sum <= lim]
当然,如果blocksize
元素的样本总和是&lt; lim然后这将无法给你一个总和为> = lim的样本。您可以检查是否是这种情况,并在必要时将其附加到循环中。
def fast_samp_2(refs, lim=10000, blocksize=10000):
samp = np.random.choice(refs, size=blocksize, replace=True)
samp_sum = np.cumsum(samp)
# is the sum of our current block of samples >= lim?
while samp_sum[-1] < lim:
# if not, we'll sample another block and try again until it is
newsamp = np.random.choice(refs, size=blocksize, replace=True)
samp = np.hstack((samp, newsamp))
samp_sum = np.hstack((samp_sum, np.cumsum(newsamp) + samp_sum[-1]))
last = np.searchsorted(samp_sum, lim, side='right')
return samp[:last + 1]
请注意,连接数组非常慢,因此最好使blocksize
足够大,以便合理地确定单个块的总和将≥>到您的限制,而不会过大
我已经调整了你原来的功能,所以它的语法更接近我的。
def orig_samp(refs, lim=10000):
# Initialize empty final list.
b = []
a1 = np.random.random(10000)
# Run until the condition is met.
while (sum(b) < lim):
# Draw random [0,1) value.
u = np.random.random()
# Find closest value in sub-list a[1].
idx = np.argmin(np.abs(u - a1))
# Store value located in sub-list a[0].
b.append(refs[idx])
return b
这是一些基准数据。
%timeit orig_samp(refs, lim=10000)
# 100 loops, best of 3: 11 ms per loop
%timeit fast_samp_2(refs, lim=10000, blocksize=1000)
# 10000 loops, best of 3: 62.9 µs per loop
这要快3个数量级。你可以通过减少块大小来做得更好一点 - 你基本上希望它比你得到的阵列的长度更舒适。在这种情况下,您知道平均输出长度约为200个元素,因为0到100之间的所有实数的平均值为50,而10000/50 = 200。
获取加权样本而不是统一样本很容易 - 您只需将p=
参数传递给np.random.choice
:
def weighted_fast_samp(refs, weights=None, lim=10000, blocksize=10000):
samp = np.random.choice(refs, size=blocksize, replace=True, p=weights)
samp_sum = np.cumsum(samp)
# is the sum of our current block of samples >= lim?
while samp_sum[-1] < lim:
# if not, we'll sample another block and try again until it is
newsamp = np.random.choice(refs, size=blocksize, replace=True,
p=weights)
samp = np.hstack((samp, newsamp))
samp_sum = np.hstack((samp_sum, np.cumsum(newsamp) + samp_sum[-1]))
last = np.searchsorted(samp_sum, lim, side='right')
return samp[:last + 1]
答案 1 :(得分:0)
对参考数组进行排序。
这允许log(n)
查找而不需要浏览整个列表。 (例如,使用bisect
查找最接近的元素)
对于初学者,我反转[0]和[1]以简化排序:
a = np.sort([np.random.random(10000), np.random.uniform(0., 100., 10000)])
现在,a按[0]的顺序排序,这意味着如果您要查找与任意数字最接近的值,您可以从二等分开始:
while (sum(b) < 10000):
# Draw random [0,1) value.
u = np.random.random()
# Find closest value in sub-list a[0].
idx = bisect.bisect(a[0], u)
# now, idx can either be idx or idx-1
if idx is not 0 and np.abs(a[0][idx] - u) > np.abs(a[0][idx - 1] - u):
idx = idx - 1
# Store value located in sub-list a[1].
b.append(a[1][idx])
答案 2 :(得分:0)
一个明显的优化 - 不要在每次迭代时重新计算总和,累积它
b_sum = 0
while b_sum<10000:
....
idx = np.argmin(np.abs(u - a[1]))
add_val = a[0][idx]
b.append(add_val)
b_sum += add_val
修改强>
我认为可以通过在循环之前预先引用子列表来实现一些小的改进(如果您愿意,请查看它)
a_0 = a[0]
a_1 = a[1]
...
while ...:
....
idx = np.argmin(np.abs(u - a_1))
b.append(a_0[idx])
它可能会节省一些运行时间 - 虽然我不相信它会那么重要。
答案 3 :(得分:0)
用cython编写。对于高迭代操作,这将为您提供更多帮助。