鉴于S
是一个n x m
矩阵,作为一个numpy数组,我想在成对的f
上调用函数(S[i], S[j])
来计算感兴趣的特定值,并存储尺寸为n x n
的矩阵中。在我的特殊情况下,函数f
是可交换的,因此f(x,y) = f(y,x)
。
考虑到所有这些,我想知道我是否可以采取任何措施尽可能快地加快速度,n
可能会很大。
当我对函数f计时时,大约是几微秒,这是预期的。这是一个非常简单的计算。下面我将与max()
和sum()
比较以供参考。
In [19]: %timeit sum(s[11])
4.68 µs ± 56.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [20]: %timeit max(s[11])
3.61 µs ± 64.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [21]: %timeit f(s[11], s[21], 50, 10, 1e-5)
1.23 µs ± 7.25 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [22]: %timeit f(s[121], s[321], 50, 10, 1e-5)
1.26 µs ± 31.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
但是,当我为500x50样本数据计时总处理时间(得出500 x 500/2 = 125K比较)时,总时间就大大增加了(几分钟)。我本来希望像0.2-0.3秒(1.25E5 * 2E-6秒/计算)。
In [12]: @jit
...: def testf(s, n, m, p):
...: tol = 1e-5
...: sim = np.zeros((n,n))
...: for i in range(n):
...: for j in range(n):
...: if i > j:
...: delta = p[i] - p[j]
...: if delta < 0:
...: res = f(s[i], s[j], m, abs(delta), tol) # <-- max(s[i])
...: else:
...: res = f(s[j], s[i], m, delta, tol) # <-- sum(s[j])
...: sim[i][j] = res
...: sim[j][i] = res
...: return sim
在上面的代码中,我更改了将res
分配给max()
和sum()
(注释掉的部分)以进行测试的行,并且代码执行速度快了大约100倍,即使函数本身比我的函数f()
哪个带给我我的问题:
我可以避免双循环来加快速度吗?理想情况下,我希望能够对n = 1E5大小的矩阵运行此命令。 (评论:由于max和sum函数的运行速度相当快,我猜想for循环不是这里的瓶颈,但是还是知道是否有更好的方法还是很不错的。)
如果不是double for循环,是什么导致我的函数严重减速?
编辑
通过一些评论询问了函数f的细节。它在两个数组上进行迭代,并检查两个数组中“足够接近”的值的数量。我删除了注释并更改了一些变量名,但逻辑如下所示。有趣的是,math.isclose(x,y,rel_tol)
等同于我下面的if语句,这可能会使代码的运行速度大大降低,这可能是由于库调用造成的?
from numba import jit
@jit
def f(arr1, arr2, n, d, rel_tol):
counter = 0
i,j,k = 0,0,0
while (i < n and j < n and k < n):
val = arr1[j] + d
if abs(arr1[i] - arr2[k]) < rel_tol * max(arr1[i], arr2[k]):
counter += 1
i += 1
k += 1
elif abs(val - arr2[k]) < rel_tol * max(val, arr2[k]):
counter += 1
j += 1
k += 1
else:
# incremenet the index corresponding to the lightest
if arr1[i] <= arr2[k] and arr1[i] <= val:
if i < n:
i += 1
elif val <= arr1[i] and val <= arr2[k]:
if j < n:
j += 1
else:
k += 1
return counter