我有一个相当简单的代码块,我想提高其性能。它由for
块组成,该块使用np.where()查找数组中整数的索引。
下面的代码有效,但是我觉得使用for
将元素添加到空列表并不是解决此问题的最佳方法。
MCMC使用此块,因此它执行了数百万次。从小的改进变成大的改进。可以提高效率吗?
import numpy as np
N = 20
# Integers from 1 to N
ran_indexes = np.random.randint(1, N, 1000)
# Number of integers to remove
rm_number = np.random.randint(0, 100, N)
# Better performance for this block?
# For each integer from 1 to N, keep only 'd' indexes of 'ran_indexes' that
# contain that integer, where 'd' is the ith element in 'rm_number'
new_indexes = []
for i, d in enumerate(rm_number):
new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])
答案 0 :(得分:2)
列表连接的执行速度很慢+=
,因为它们每次都需要一个新列表。在迭代构建数组时,更经常地使用列表附加,它是就地的,并且每次仅在列表上添加元素。
In [45]:
...: new_indexes = []
...: for i, d in enumerate(rm_number):
...: new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
...:
In [46]: new_indexes
Out[46]:
[array([ 5, 96, 143, 150, 154, 175]),
array([ 14, 22, 26, 28, 32, 38, 46, 54, 70, 205, 218, 242, 248,
254, 271, 318, 344, 352, 357, 393, 419, 437, 448, 472, 473, 503,
521, 548, 558, 629, 631, 654, 661, 685, 699, 743, 755]),
array([ 24, 34, 72, 97, 120, 140, 173, 181, 193, 199, 200, 225, 239,
251, 265, 296, 350, 386, 411, 422, 465, 476, 506, 533, 609, 628,
680, 694, 713, 759]),
....
采用这种结构,每个数组(where
结果)的长度都不同,上限为rm_number
:
In [89]: [len(i) for i in new_indexes]-rm_number
Out[89]:
array([ 0, 0, 0, 0, 0, 0, 0, -2, -24, -40, 0, -3, -40,
0, -15, -5, 0, 0, 0, -96])
类似的可变长度数组/列表很好地表明您不能执行超快速的“向量化”(整个数组)操作,至少不是没有足够的聪明。
我们可以使用以下代码获取您的代码生成的平面列表:
In [50]: np.concatenate(new_indexes).shape
Out[50]: (626,)
一些时间:
In [53]: %%timeit
...: new_indexes = []
...: for i, d in enumerate(rm_number):
...: new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])
...:
320 µs ± 7.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [54]:
In [54]: %%timeit
...: new_indexes = []
...: for i, d in enumerate(rm_number):
...: new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
...:
184 µs ± 268 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [55]:
In [55]: %%timeit
...: new_indexes = []
...: for i, d in enumerate(rm_number):
...: new_indexes.append(np.where(ran_indexes == i + 1)[0][:d])
...: new_indexes=np.concatenate(new_indexes)
...:
...:
193 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [79]: timeit f2() # Lukas
291 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
===
temp = ran_indexes[:,None]==np.arange(1,21)
查找所有匹配项,并且np.where(temp)[0]
是索引。但这并不适用您的rm_number
界限。
np.where(temp.T)[1] # without the `rm_number` truncation
np.where(temp[:,i])[0][:d]
答案 1 :(得分:1)
为您的new_indexes
预先分配空间,并在您的测试运行中不断地将其追加到现有列表中以20-30%的速度填充,请参见下面的实现
def f1():
new_indexes = []
for i, d in enumerate(rm_number):
new_indexes += list(np.where(ran_indexes == i + 1)[0][:d])
return new_indexes
def f2():
new_indexes = np.zeros(sum(rm_number))
ind = 0
for i, d in enumerate(rm_number):
tmp = np.where(ran_indexes == i + 1)[0][:d]
new_indexes[ind:ind+tmp.shape[0]] = tmp
ind += tmp.shape[0]
return list(new_indexes[0:ind])
In [144]: %timeit f1
33.5 ns ± 1.71 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
In [145]: %timeit f2
23.6 ns ± 0.273 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
In [146]: %timeit f1
35.2 ns ± 3.74 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
In [147]: %timeit f2
24.5 ns ± 1.47 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
请注意:两个for循环的最后一次迭代都保持不变,因此rm_numbers
中的最后一个数字永远不会用于任何有生产力的事情。 ran_indexes中的最大数量为19,在上一次迭代中,您正在检查ran_indexes == 19 + 1
,该值始终为零。我不确定这是否有意,我想您想修改ran_indexes的定义以将N+1
作为上限(假设上限是互斥的)。
如果确实应该将19设为最高随机数,那么您应该可以跳过最后一个循环来节省几纳秒的时间
答案 2 :(得分:1)
我想到了两种可能的方法-删除循环(f_2
-快30%)或使用numba(f_3
-快6倍)。 Numba还需要一些不同的实现方法-更少的python,更少的复制数据工作,更多的numpy和更多的读取数据工作。不知道您是否可以使用numba,但值得一试。当然,Cython可以替代numba。但是,Cython不仅需要用numba包装函数,还需要更多的重构。
import numba as nb
import numpy as np
def f_1(ran_indexes, rm_number):
new_indexes = []
for idx, qty in enumerate(rm_number):
new_indexes += list(np.where(ran_indexes == idx + 1)[0][:qty])
return new_indexes
def f_2(ran_indexes, rm_number):
return np.hstack([np.where(ran_indexes == idx + 1)[0][:qty] for idx, qty in enumerate(rm_number)])
@nb.njit
def f_3(ran_indexes, rm_number):
ans = np.zeros(rm_number.sum(), dtype=np.int64)
count = 0
for idx in range(rm_number.shape[0]):
count_2 = 0
for idx_2 in range(ran_indexes.shape[0]):
if count_2 == rm_number[idx]:
break
if ran_indexes[idx_2] == idx + 1:
ans[count + count_2] = idx_2
count_2 += 1
count += count_2
return ans[:count]
if __name__ == '__main__':
N = 20
ran_indexes_ = np.random.randint(1, N, 1000)
rm_number_ = np.random.randint(0, 100, N - 1)
ans_1 = f_1(ran_indexes_, rm_number_)
ans_2 = f_2(ran_indexes_, rm_number_)
ans_3 = f_3(ran_indexes_, rm_number_)
# check results
print(sum(ans_1), sum(ans_2), sum(ans_3))
print(len(ans_1), len(ans_2), len(ans_3))
结果:
%timeit f_1(ran_indexes_, rm_number_)
111 µs ± 279 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f_2(ran_indexes_, rm_number_)
77 µs ± 118 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f_3(ran_indexes_, rm_number_)
17 µs ± 6.01 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)