我的程序中有一个函数可以计算出相关系数。它需要两个平面(一维)numpy数组并对它们执行必要的计算,以计算出两个数字列表之间的相关性(它们是float类型的货币)。只要程序运行,每个循环执行136次此函数,每个循环大约需要0.05秒。以下代码按预期计算系数:
def CC(a, b):
a = a - np.mean(a)
b = b - np.mean(b)
ab = np.sum(a*b)
asq = np.sum(a**2)
bsq = np.sum(b**2)
cc = round(ab / sqrt(asq * bsq), 5)
return cc
然而,它最终会导致内存泄漏。此内存泄漏的解决方案是将函数更改为:
def CC(a, b):
cc = round(np.sum((a - np.mean(a)) * (b - np.mean(b))) / sqrt(np.sum(a**2) * np.sum(b**2)), 5)
return cc
它可以在一行中完成所有工作,并且不会创建任何新列表,从而节省内存并避免泄漏。
然而,由于一些奇怪的原因,当使用方法2时,返回值从0.1 ish开始,然后在大约20秒的过程中趋势为0,然后从那时开始保持为0。这种情况每次都会发生。我已经尝试了方法2的替代方案,即1或2个额外的计算步骤 - 相同的结果。我已经通过消除过程隔离了所有可能的错误来源,并且它们都归结为函数本身内部发生的事情,所以它必须是一个问题。究竟是什么导致了这个?好像功能CC无视它给出的输入......如果它以某种方式设置......?
答案 0 :(得分:2)
您的代码不相同,第一个代码在第一步中重新分配a
和b
:
a = a - np.mean(a)
b = b - np.mean(b)
以及所有后续操作都使用更新后的a
和b
。然而,您的第二种方法在sqrt
- 术语中忽略了这些:
sqrt(np.sum(a**2) * np.sum(b**2))
它应该与:
相同sqrt(np.sum((a-a.mean())**2) * np.sum((b-b.mean())**2))
其他一些评论:
它可以在一行中完成所有操作,并且不会创建任何新列表,从而节省内存。
这不是真的(至少不总是如此),它仍然会产生新的数组。但是我可以看到两个你可以避免创建中间数组的地方:
np.subtract(a, a.mean(), out=a)
# instead of "a = a - np.mean(a)"
# possible also "a -= a" should work without temporary array, but I'm not 100% sure.
b = b - np.mean(b)
然而,它最终会导致内存泄漏。
我无法在第一个函数中找到任何内存泄漏的证据。
如果您关心中间阵列,您可以自己进行操作。我用numba显示它,但这可以很容易地移植到cython或类似(但我不需要添加类型注释):
import numpy as np
import numba as nb
from math import sqrt
@nb.njit
def CC_helper(a, b):
sum_ab = 0.
sum_aa = 0.
sum_bb = 0.
for idx in range(a.size):
sum_ab += a[idx] * b[idx]
sum_aa += a[idx] * a[idx]
sum_bb += b[idx] * b[idx]
return sum_ab / sqrt(sum_aa * sum_bb)
def CC1(a, b):
np.subtract(a, a.mean(), out=a)
np.subtract(b, b.mean(), out=b)
res = CC_helper(a, b)
return round(res, 5)
并将性能与两个函数进行比较:
def CC2(a, b):
a = a - np.mean(a)
b = b - np.mean(b)
ab = np.sum(a*b)
asq = np.sum(a**2)
bsq = np.sum(b**2)
cc = round(ab / sqrt(asq * bsq), 5)
return cc
def CC3(a, b):
cc = round(np.sum((a - np.mean(a)) * (b - np.mean(b))) / sqrt(np.sum((a - np.mean(a))**2) * np.sum((b - np.mean(b))**2)), 5)
return cc
确保结果相同并计时:
a = np.random.random(100000)
b = np.random.random(100000)
assert CC1(arr1, arr2) == CC2(arr1, arr2)
assert CC1(arr1, arr2) == CC3(arr1, arr2)
%timeit CC1(arr1, arr2) # 100 loops, best of 3: 2.06 ms per loop
%timeit CC2(arr1, arr2) # 100 loops, best of 3: 5.98 ms per loop
%timeit CC3(arr1, arr2) # 100 loops, best of 3: 7.98 ms per loop