Cython 没有加速

时间:2021-07-30 16:52:21

标签: python numpy cython

与基本 Python 代码相比,我一直在尝试测试使用 Cython 的加速潜力。为此,我编写了两个脚本“linearAdvec_mat.py”和“linearAdvec_mat.pyx”,如下所示:

linearAdvec_mat.py:

import numpy as np

def Adv_mat(N):
    A = np.zeros((N, N));
    for i in range(N):
        if i == 0:
            A[i, N - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, i + 1] = 1.0;
        elif i == N - 1:
            A[i, i - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, 0] = 1.0;
        else:
            A[i, i - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, i + 1] = 1.0;
    return A;

def Diff_mat(N):
    D = np.zeros((N, N));
    for i in range(N):
        if i == 0:
            D[i, N - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, i + 1] = 1.0;
        elif i == N - 1:
            D[i, i - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, 0] = 1.0;
        else:
            D[i, i - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, i + 1] = 1.0;
    return D;

def Compute_eigVals(N, alpha, kdt):
    A = Adv_mat(N);
    D = Diff_mat(N);
    ADt = A*(-alpha/2.0) + D*kdt;
    ldt = np.zeros(N, 'complex');
    beta = np.zeros(N);
    for m in range(N):
        beta[m] = 2*np.pi*m/N;
        if beta[m] > np.pi:
            beta[m] = 2*np.pi - beta[m];
        for j in range(N):
            ldt[m] += ADt[0, j]*np.exp(1j*2.0*np.pi*j*m/N);
    return ldt;

和linearAdvec_mat.pyx:

import numpy as np
cimport numpy as np

DTYPE = np.float64;
DTYPE_c = np.complex128;
ctypedef np.float64_t DTYPE_t;

cdef np.ndarray[DTYPE_t, ndim = 2] Adv_mat(int N):
    cdef np.ndarray[DTYPE_t, ndim = 2] A = np.zeros((N, N), dtype = DTYPE);
    cdef int i;
    for i in range(N):
        if i == 0:
            A[i, N - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, i + 1] = 1.0;
        elif i == N - 1:
            A[i, i - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, 0] = 1.0;
        else:
            A[i, i - 1] = -1.0;
            A[i, i] = 0.0;
            A[i, i + 1] = 1.0;
    return A;

cdef np.ndarray[DTYPE_t, ndim = 2] Diff_mat(int N):
    cdef np.ndarray[DTYPE_t, ndim = 2] D = np.zeros((N, N), dtype = DTYPE);
    cdef int i;
    for i in range(N):
        if i == 0:
            D[i, N - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, i + 1] = 1.0;
        elif i == N - 1:
            D[i, i - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, 0] = 1.0;
        else:
            D[i, i - 1] = 1.0;
            D[i, i] = -2.0;
            D[i, i + 1] = 1.0;
    return D;

cpdef np.ndarray[np.complex128_t, ndim = 1] Compute_eigVals(int N, double alpha, double kdt):
    cdef np.ndarray[DTYPE_t, ndim = 2] A = Adv_mat(N);
    cdef np.ndarray[DTYPE_t, ndim = 2] D = Diff_mat(N);
    cdef np.ndarray[np.complex128_t, ndim = 2] ADt = A*(-alpha/2.0) + D*kdt + 0j;
    cdef np.ndarray[np.complex128_t, ndim = 1] ldt = np.zeros(N, dtype = DTYPE_c);
    cdef np.ndarray[DTYPE_t, ndim = 1] beta = np.zeros(N, dtype = DTYPE);
    cdef int m, k;
    for m in range(N):
        beta[m] = 2*np.pi*m/N;
        if beta[m] > np.pi:
            beta[m] = 2*np.pi - beta[m];
        for k in range(N):
            ldt[m] = ldt[m] + ADt[0, k]*np.exp(1j*2.0*np.pi*k*m/N);
    return ldt;

当我从基础 python 和编译后的 .so 文件调用“Compute_eigVals”函数时,如下所示,我没有从 cython 脚本中获得任何显着的加速。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
from libs.linearAdvec_mat import Compute_eigVals as Compute_eigVals_cy
from linearAdvec_mat import Compute_eigVals as Compute_eigVals_py
import time
#%% ------------------ Inputs ---------------------
N = 1000;
alpha = 0.8;
kdt = 0.05;

st = time.time();
eigs = Compute_eigVals_cy(N, alpha, kdt);
t_cy = time.time() - st;
print('Cython time : %0.8fs\n'%(t_cy));

st = time.time();
eigs = Compute_eigVals_py(N, alpha, kdt);
t_py = time.time() - st;
print('Python time : %0.8fs\n'%(t_py));
print('Cython is %0.5f times faster'%(t_py/t_cy));

我试图通过运行来检查 python 交互的数量

cython -a linearAdvec_mat.pyx

在终端中,但我无法从中解决任何问题。有人可以提供一些关于为什么我在使用 cython 时没有获得大量加速的见解吗?我的第一个猜测是,我的基础 Python 脚本严重依赖于 numpy,它已经处于优化状态,但我完全确定并渴望弄清楚实际发生了什么。

1 个答案:

答案 0 :(得分:6)

Cython 解决方案:

让我们为您的 Python 函数计时作为基准参考:

In [3]: %timeit Compute_eigVals(N, alpha, kdt)
3.85 s ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

通过在 jupyter notebook 中分析您的 Python 代码

In [4]: %lprun -f Compute_eigVals Compute_eigVals(N, alpha, kdt)
Timer unit: 1e-06 s

Total time: 4.35475 s
File: <ipython-input-1-61dba133ade4>
Function: Compute_eigVals at line 37

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    37                                           def Compute_eigVals(N, alpha, kdt):
    38         1       2491.0   2491.0      0.1      A = Adv_mat(N);
    39         1       2295.0   2295.0      0.1      D = Diff_mat(N);
    40         1       8582.0   8582.0      0.2      ADt = A*(-alpha/2.0) + D*kdt;
    41         1         11.0     11.0      0.0      ldt = np.zeros(N, 'complex');
    42         1          2.0      2.0      0.0      beta = np.zeros(N);
    43      1001        357.0      0.4      0.0      for m in range(N):
    44      1000        713.0      0.7      0.0          beta[m] = 2*np.pi*m/N;
    45      1000        720.0      0.7      0.0          if beta[m] > np.pi:
    46       499        356.0      0.7      0.0              beta[m] = 2*np.pi - beta[m];
    47   1001000     390717.0      0.4      9.0          for j in range(N):
    48   1000000    3948510.0      3.9     90.7              ldt[m] += ADt[0, j]*np.exp(1j*2.0*np.pi*j*m/N);
    49         1          1.0      1.0      0.0      return ldt;

我们可以观察到时间关键部分是最里面的循环。那么让我们来看看你的 cython 代码:

enter image description here

以下是减少python开销的几个关键点:

  • 访问常量 np.pi 有明显的 Python 开销。相反,您可以在 pi 中使用 C 常量 libc.math。此外,您可以缓存 2.0*pi1j*2.0*pi 的结果,因为您多次使用它们。
  • 同样,函数 np.exp 也有 python 开销,并且为标量参数调用它并不能证明调用 python 函数的开销是合理的。相反,您可以使用 C cexp 函数。
  • 最后,您可以使用 Cython Compiler directives 进一步加速您的代码。在这里,我们启用 C 整数除法 (cdivision)、禁用索引检查 (boundscheck) 和禁用负索引 (wraparound)

在代码中:

cimport cython

from libc.math cimport pi
cdef extern from "complex.h":
    double complex cexp(double complex)

# Adv_mat and Diff_mat are the same as above

@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef np.ndarray[np.complex128_t, ndim = 1] Compute_eigVals(int N, double alpha, double kdt):
    cdef np.ndarray[DTYPE_t, ndim = 2] A = Adv_mat(N)
    cdef np.ndarray[DTYPE_t, ndim = 2] D = Diff_mat(N)
    cdef np.ndarray[np.complex128_t, ndim = 2] ADt = A*(-alpha/2.0) + D*kdt + 0j
    cdef np.ndarray[np.complex128_t, ndim = 1] ldt = np.zeros(N, dtype = DTYPE_c)
    cdef np.ndarray[DTYPE_t, ndim = 1] beta = np.zeros(N, dtype = DTYPE)
    cdef int m, k
    cdef double two_pi = 2*pi
    cdef double complex factor = 1j*2.0*pi+0
    for m in range(N):
        beta[m] = two_pi*m / N;
        if beta[m] > pi:
            beta[m] = two_pi - beta[m];
        for k in range(N):
            ldt[m] = ldt[m] + ADt[0, k]*cexp(factor*k*m / N);
    return ldt;

这消除了循环内的所有 python 交互。在我的机器上计时会给出:

In [6]: %timeit Compute_eigVals(N, alpha, kdt)
45.8 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

改进的 Python 版本:

还要注意,Cython 没有真正的需要,因为您可以用矢量化的 numpy 操作替换您的 python 循环:

def Compute_eigVals2(N, alpha, kdt):
    A = Adv_mat(N);
    D = Diff_mat(N);
    ADt = A*(-alpha/2.0) + D*kdt;
    beta = 2*np.pi*np.arange(N)/N
    beta[beta > np.pi] = 2*np.pi - beta[beta > np.pi]
    JM = np.arange(N) * np.arange(N)[:, None]
    ldt = np.sum(ADt[0, :] * np.exp(1j*2.0*np.pi*JM/N), axis=-1)
    return ldt
In [7]: %timeit Compute_eigVals2(N, alpha, kdt)
35.8 ms ± 655 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
相关问题