如何在cython中并行化numpy操作

时间:2016-03-26 13:13:41

标签: python performance numpy parallel-processing cython

我正在尝试并行化以下代码,其中包括许多numpy数组操作

    #fft_fit.pyx
    import cython
    import numpy as np
    cimport numpy as np
    from cython.parallel cimport prange
    from libc.stdlib cimport malloc, free

    dat1 = np.genfromtxt('/home/bagchilab/Sumanta_files/fourier_ecology_sample_data_set.csv',delimiter=',')
    dat = np.delete(dat1, 0, 0)
    yr = np.unique(dat[:,0])
    fit_dat = np.empty([1,2])


    def fft_fit_yr(np.ndarray[double, ndim=1] yr, np.ndarray[double, ndim=2] dat, int yr_idx, int pix_idx):
        cdef np.ndarray[double, ndim=2] yr_dat1 
        cdef np.ndarray[double, ndim=2] yr_dat
        cdef np.ndarray[double, ndim=2] fft_dat
        cdef np.ndarray[double, ndim=2] fft_imp_dat
        cdef int len_yr = len(yr)
        for i in prange(len_yr ,nogil=True):
            with gil:

                yr_dat1 = dat[dat[:,yr_idx]==yr[i]]
                yr_dat = yr_dat1[~np.isnan(yr_dat1).any(axis=1)]
                print "index" ,i
                y_fft = np.fft.fft(yr_dat[:,pix_idx])
                y_fft_abs = np.abs(y_fft)
                y_fft_freq = np.fft.fftfreq(len(y_fft), 1)
                x_fft = range(len(y_fft))
                fft_dat = np.column_stack((y_fft, y_fft_abs))
                cut_off_freq = np.percentile(y_fft_abs, 25)
                imp_freq =  np.array(y_fft_abs[y_fft_abs > cut_off_freq])
                fft_imp_dat = np.empty((1,2))
        for j in range(len(imp_freq)):
                    freq_dat = fft_dat[fft_dat[:, 1]==imp_freq[j]]
                    fft_imp_dat  = np.vstack((fft_imp_dat , freq_dat[0,:]))       
                fft_imp_dat = np.delete(fft_imp_dat, 0, 0)
                fit_dat1 = np.fft.ifft(fft_imp_dat[:,0])
                fit_dat2 = np.column_stack((fit_dat1.real, [yr[i]] * len(fit_dat1)))
                fit_dat = np.concatenate((fit_dat, fit_dat2), axis = 0) 

我已将以下代码用于setup.py

    ####setup.py
    from distutils.core import setup
    from distutils.extension import Extension
    from Cython.Distutils import build_ext

    setup(
cmdclass = {'build_ext': build_ext},
ext_modules = [Extension("fft_fit_yr", ["fft_fit.pyx"])]
    extra_compile_args=['-fopenmp'],
    extra_link_args=['-fopenmp'])]
    )

但是当我在cython中编译fft_fit.pyx时出现以下错误:

    for i in prange(len_yr ,nogil=True):
    target may not be a Python object as we don't have the GIL

使用prange功能时,请告诉我出错的地方。 感谢。

1 个答案:

答案 0 :(得分:3)

你不能(至少不使用Cython)。

Numpy函数对Python对象进行操作,因此需要GIL,这可以防止多个本机线程并行执行。如果使用cython -a编译代码,您将获得一个带注释的HTML文件,该文件显示正在进行Python C-API调用的位置(因此无法释放GIL)。

如果您的代码中存在特定的瓶颈,使用矢量化无法轻松加速,则Cython非常有用。如果你的代码已经花费了大部分时间在numpy函数调用中,那么从Cython中调用那些完全相同的函数不会导致任何显着的性能提升。为了看到明显的差异,您需要将部分或全部数组操作写为显式for循环。但是,它看起来好像可以对您的代码进行更简单的优化。

我建议您执行以下操作:

  1. 描述原始Python代码(例如,使用line_profiler)以查看瓶颈所在。
  2. 将注意力集中在加速单线程版本中的这些瓶颈。如果你需要帮助,你应该问一个单独的问题。
  3. 如果优化的单线程版本仍然太慢而无法满足您的需求,请使用joblibmultiprocessing对其进行并行化。并行化通常是 last 工具,一旦您已经尝试过其他您能想到的其他内容,就可以使用它。