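I'm seeing a performance gap between a call to cblas (namely daxpy, which performs y += alpha * x, where y and x are vectors of the same length and alpha is a scalar) and the same operation written in pure Cython.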
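To make the operation concrete, here is a minimal pure-Python sketch of what a daxpy call computes, including the increment arguments. The name `daxpy_reference` is hypothetical and only for illustration. It assumes positive increments and plain Python lists, and it says nothing about how a real BLAS implements the routine.

```python
def daxpy_reference(n, alpha, x, incx, y, incy):
    """Illustrative only: y[j * incy] += alpha * x[j * incx] for j in range(n).

    Assumes incx > 0 and incy > 0; updates y in place and returns it.
    """
    for j in range(n):
        y[j * incy] += alpha * x[j * incx]
    return y


x = [1.0, 2.0, 3.0]
y = [10.0, 20.0, 30.0]
daxpy_reference(3, 2.0, x, 1, y, 1)
print(y)  # [12.0, 24.0, 36.0]
```

With both increments equal to 1, as in the benchmark below, this reduces to an element-wise update over two contiguous vectors.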
Here is my benchmark, using this SO question to line-profile the Cython code:
%load_ext line_profiler
import line_profiler
%load_ext Cython
from Cython.Compiler.Options import directive_defaults
directive_defaults['linetrace'] = True
directive_defaults['binding'] = True
The code snippet I want to benchmark:
%%cython -a -f --compile-args=-DCYTHON_TRACE=1
import numpy as np
cimport cython
from scipy.linalg.cython_blas cimport daxpy
@cython.boundscheck(False)
def bench(double[:, ::1] G, double[::1, :] Q, double[:] x):
    cdef int T = G.shape[1]
    cdef int P = G.shape[0]
    cdef int inc = 1  # increment for blas
    cdef int p
    cdef int t
    cdef int i = 42
    for _ in range(50):  # 50 repetitions
        # First version:
        for p in range(P):
            for t in range(T):
                G[p, t] += Q[p, i] * x[t]
        # second version
        for p in range(P):
            daxpy(&T, &Q[p, i], &x[0], &inc, &G[p, 0], &inc)
The benchmark results are:
Timer unit: 1e-06 s
Total time: 18.543 s
File: /home/mathurin/.cache/ipython/cython/_cython_magic_32a150cd3ff68e78f896ad9eb33dda69.pyx
Function: bench at line 9
Line #       Hits            Time  Per Hit   % Time  Line Contents
==================================================================
     9                                               def bench(double[:, ::1] G, double[::1, :] Q, double[:] x):
    10          1               1      1.0      0.0      cdef int T = G.shape[1]
    11          1               0      0.0      0.0      cdef int P = G.shape[0]
    12
    13          1               1      1.0      0.0      cdef int inc = 1 # increment for blas
    14                                                   cdef int p
    15                                                   cdef int t
    16          1               0      0.0      0.0      cdef int i = 42
    17                                                   # First version:
    18          1               0      0.0      0.0      for _ in range(50): # 50 repetitions
    19         50              22      0.4      0.0          for p in range(P):
    20       9000            2512      0.3      0.0              for t in range(T):
    21   63000000    **18002618**      0.3     97.1                  G[p, t] += Q[p, i] * x[t]
    22
    23                                                       # second version
    24         50              15      0.3      0.0          for p in range(P):
    25       9000      **537865**     59.8      2.9              daxpy(&T, &Q[p, i], &x[0], &inc, &G[p, 0], &inc)
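Lines 21 and 25 show that the for loop is about 33 times slower than the BLAS call (roughly 18.0 s versus 0.54 s in total). Unless I'm mistaken, I iterate over the arrays in the right order (to maximize cache hits). I expected BLAS to be faster, but not by this much. Is there an obvious explanation I'm missing?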
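For reference, the gap can be worked out directly from the profiler table. This is a small sketch using only the numbers reported above (times are in microseconds, per the `Timer unit` line). The variable names are mine, and `T = 7000` is taken from the shape of `G` in the timing snippet.

```python
# Figures copied from lines 21 and 25 of the profiler output (microseconds).
loop_time_us = 18002618
loop_hits = 63000000      # one hit per updated element
daxpy_time_us = 537865
daxpy_hits = 9000         # one hit per daxpy call
T = 7000                  # elements handled by each daxpy call

ratio = loop_time_us / daxpy_time_us
loop_ns_per_element = loop_time_us / loop_hits * 1000
daxpy_ns_per_element = daxpy_time_us / (daxpy_hits * T) * 1000

print(f"slowdown: {ratio:.1f}x")                          # about 33.5x
print(f"loop:  {loop_ns_per_element:.1f} ns per element")  # about 286 ns
print(f"daxpy: {daxpy_ns_per_element:.1f} ns per element") # about 8.5 ns
```

Note that line 21 is hit 63,000,000 times while line 25 is hit only 9,000 times, so any per-line cost added by the profiler would weigh far more heavily on the loop version.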
The snippet used to obtain the timings is:
np.random.seed(2407)
G = np.random.randn(180, 7000)
Q = np.random.randn(50, 50)
Q += Q.T
Q = np.asfortranarray(Q)
x = np.random.randn(50)
import pstats, cProfile
profile = line_profiler.LineProfiler(bench)
profile.runcall(bench, G, Q, x)
profile.print_stats()
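As a possible cross-check, the two versions could also be timed with a plain wall clock instead of the line profiler. The helper below is a minimal sketch using only the standard library. `best_wall_time` and `dummy_workload` are hypothetical names, and applying it here assumes the two loops are split into separate functions compiled without `CYTHON_TRACE`, which is not what the code above does.

```python
import time


def best_wall_time(fn, *args, repeats=5):
    """Return the smallest wall-clock duration in seconds over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        elapsed = time.perf_counter() - start
        best = min(best, elapsed)
    return best


def dummy_workload(n):
    # Stand-in for a function under test; replace with the real callable.
    return sum(i * i for i in range(n))


print(best_wall_time(dummy_workload, 100_000))
```

Taking the minimum over several runs is one common way to reduce noise from other processes. Since `bench` modifies `G` in place, repeated calls would operate on updated values, which should not affect the timing itself.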