我正在使用numpy
做一些矢量代数,而我的算法的挂钟性能似乎很奇怪。该程序大致执行以下操作:
Y
(KxD),X
(NxD),T
(KxN)Y
的每一行:Y[i]
的每一行中减去X
(通过广播),T
中。但是,根据我进行广播的方式,计算速度有很大不同。考虑代码:
import numpy as np
from time import perf_counter
D = 128
N = 3000
K = 500
X = np.random.rand(N, D)
Y = np.random.rand(K, D)
T = np.zeros((K, N))
if True: # negate to enable the second loop
time = 0.0
for i in range(100):
start = perf_counter()
for i in range(K):
T[i] = np.sqrt(np.sum(
np.square(
X - Y[i] # this has dimensions NxD
),
axis=1
))
time += perf_counter() - start
print("Broadcast in line: {:.3f} s".format(time / 100))
exit()
if True:
time = 0.0
for i in range(100):
start = perf_counter()
for i in range(K):
diff = X - Y[i]
T[i] = np.sqrt(np.sum(
np.square(
diff
),
axis=1
))
time += perf_counter() - start
print("Broadcast out: {:.3f} s".format(time / 100))
exit()
每个循环的时间分别进行测量,并平均执行100次。结果:
Broadcast in line: 1.504 s
Broadcast out: 0.438 s
唯一的区别是,第一个循环中的广播和减法是在线完成的,而第二种方法中,我在进行任何矢量化运算之前都进行了广播和减法。为什么会产生如此不同?
我的系统配置:
np.__config__.show()
告诉我的话就这么多)PS:是的,我知道这可以进一步优化,但是现在我想了解这里发生的情况。
答案 0 :(得分:1)
我还添加了一个优化的解决方案,以查看实际的计算需要花费多长时间,而不会占用大量的内存分配和释放。
功能
import numpy as np
import numba as nb
def func_1(X,Y,T):
for i in range(K):
T[i] = np.sqrt(np.sum(np.square(X - Y[i]),axis=1))
return T
def func_2(X,Y,T):
for i in range(K):
diff = X - Y[i]
T[i] = np.sqrt(np.sum(np.square(diff),axis=1))
return T
@nb.njit(fastmath=True,parallel=True)
def func_3(X,Y,T):
for i in nb.prange(Y.shape[0]):
for j in range(X.shape[0]):
diff_sq_sum=0.
for k in range(X.shape[1]):
diff_sq_sum+= (X[j,k] - Y[i,k])**2
T[i,j]=np.sqrt(diff_sq_sum)
return T
时间
我在Jupyter笔记本中进行了所有计时,并观察到一个非常奇怪的行为。以下代码在一个单元格中。我也尝试多次调用timit,但是在第一次执行单元格时,它什么都没有改变。
单元的首次执行
D = 128
N = 3000
K = 500
X = np.random.rand(N, D)
Y = np.random.rand(K, D)
T = np.zeros((K, N))
#You can do it more often it would not change anything
%timeit func_1(X,Y,T)
%timeit func_1(X,Y,T)
#You can do it more often it would not change anything
%timeit func_2(X,Y,T)
%timeit func_2(X,Y,T)
###Avoid measuring compilation overhead###
%timeit func_3(X,Y,T)
##########################################
%timeit func_3(X,Y,T)
774 ms ± 6.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
768 ms ± 2.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
494 ms ± 2.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
494 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
10.7 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.74 ms ± 39.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
第二次执行
345 ms ± 16.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
337 ms ± 3.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
322 ms ± 834 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
323 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.93 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.9 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)