Cython代码比相应的NumPy版本慢3倍

时间:2015-03-30 17:58:01

标签: python performance numpy cython

我目前正在撰写关于使用粒子滤波器滤除随机波动率模型中潜在状态的论文。为了改善过滤结果,我将选项价格作为观察过程添加。这意味着对于给定的时间序列,我必须计算每个时间步的期权价格 - "正常"时间序列是100-200点。

在没有深入算法的情况下,我遇到了严重的性能问题。最后一个for循环遍历我使用的所有粒子,大约是1000(由M确定)。仅运行一个粒子的代码需要0.25秒 - 这意味着每个时间步使用1,000个粒子运行大约需要4分钟(这是不可行的)。

from __future__ import division
import numpy as np
import numexpr as ne
from fftInC import fft
import time
import math
import pyfftw

def HestonCallPrice(M, N, S, V, t, T, strikes, r, param, b, NFFT, inp, v, alphaC, eta, k, weights):
    """
    This will be the pricing function for the European call option. Since we found the 
    quadrature procedure to be too slow we shall move on to use FFT instead. 

    So, we begin defining all of the constants etc. 
    """  

    vT, weightsT, inpJ, vJT = v.T, weights.T, inp * 1j, v.T * 1j

    p1, p2, p3_2, p3, p4 = param[1,:], param[2,:], param[3,:], np.sqrt(param[3,:]), param[4,:]

    """
    Next we move on to the calculations. These have been found to be rather fast, and hence do not
    need any Cythonization.
    """

    gamma = p3_2 / 2

    beta = ne.evaluate("p1 - p4 * p3 * 1j * inp")

    alpha = ne.evaluate("(-inp**2 - inpJ)/2")

    d = ne.evaluate("sqrt(beta**2 - 4 * alpha * gamma)")

    r_pos, r_neg = ne.evaluate("(beta + d)/(2 * gamma)"), ne.evaluate("(beta - d)/(2 * gamma)")

    g, inpJT = ne.evaluate("r_neg / r_pos"), inpJ.T

    D = ne.evaluate("r_neg * (1 - exp( -d * (T - t) ) ) / (1 - g * exp( -d * (T - t) ) )" )

    C = ne.evaluate("p1 * (r_neg*(T - t) - 2 / p3_2 * log( (1 - g*exp(-d*(T - t)))/(1 - g) ) )")

    A = 1j * inp.T * (math.log(S) + r * (T - t))

    C_tmp = (C * p2).T

    """
    The matrices and vectors that are sent into the Cython version of the code are

        A =  (1, 2048)

        C_tmp = (4, 2048)

        D.T = (4, 2048)

        V = (4, 1000)

        vJT[0, :] = (2048,)

        k[:, 0] = (2048,)

        weights.T[0, :] = (2048,)

    This is now where we call the Cython script.
    """

    start = time.time()

    prices = fft(A, float(r), float(t), float(T), C_tmp, D.T, V, float(alphaC), vJT[0, :], k[:, 0],

                 float(b), strikes, float(eta), weights.T[0, :])   

    print 'Cythonized version: ', time.time() - start, ' seconds'

    """
    The below code is the original code which has been "cythonized". 
    """
    start = time.time()

    outPrices = np.empty( (M, N) )

    prices = np.empty( (M * N, len(strikes)) )  

    """
    Regularly I use pyFFTW since it's a bit faster, but I couldn't figure out how to use the C 
    version of this, so to be fair when comparing speeds I disable pyFFTW. However, turning this on
    using the below settings it's 20-30% faster. 
    """
#    fftIn = pyfftw.n_byte_align_empty((N, NFFT), 16, 'complex128')
#    
#    fftOut = fftIn.copy()
#    
#    fft_object = pyfftw.FFTW(fftIn, fftOut, nthreads=8)

    for j in range( len(strikes) ):

        position = (np.log(strikes[j]) + b) / ( 2 * b / NFFT)

        x_1 = np.exp( k[ int(math.floor(position)) ] )

        x_2 = np.exp( k[ int(math.ceil(position)) ] )

        for m in range(M):

            C_m, D_m, V_m = C_tmp[m, :], D[:, m].T, V[m, :][:, np.newaxis]

            F_cT =  ne.evaluate("exp( -r*(T - t) ) * exp(C_m + D_m * V_m + A)  / \
                     ( (alphaC + vJT) * (alphaC + 1 + vJT) )") 

            toFFT = ne.evaluate("exp( b * vJT )  * F_cT * eta / 3 * weightsT")

            price = np.exp( -alphaC * k.T ) / math.pi * np.real ( np.fft.fft(toFFT) )

            y_1 = price[ :, int(math.floor(position)) ]

            y_2 = price[ :, int(math.ceil(position)) ]

            dydx = (y_2 - y_1)/(x_2 - x_1)  

            outPrices[m, :] = dydx * (strikes[j] - x_1) + y_1

        prices[:, j] = outPrices.reshape(M * N)

    print 'Non-cythonized version: ', time.time() - start, ' seconds'

    return prices

" ------ Defining constants etc, nothing to say really -----  "    
M, N, S, t, T, r, NFFT, alphaC = 1, 1000, 1000, 0, 1, 0, 2048, 1.5

strikes = np.array([900, 1100])

c, V = 600, np.random.normal(loc=0.2, scale=0.05, size=(M, N))

param = np.repeat(np.array([0.05, 0.5, 0.15, 0.15**2, 0]), M).reshape((5, M))

eta = c / NFFT

b = np.pi / eta

j = np.arange(1, NFFT+1)[:, np.newaxis]

v, k = eta * (j - 1), -b + 2 * b/ NFFT*(j - 1)

inp = v - (alphaC + 1)*1j

weights = 3 + (-1)**j - np.array([1] + [0]*(NFFT-1))[:, np.newaxis]

" ------------------------------------------------------------- "

HestonCallPrice(M, N, S, V, t, T, strikes, r, param, b, NFFT, inp, v, alphaC, eta, k, weights)

我发现瓶颈是最后一个for循环。我得到了一个提示,在Cython中重写for循环,见下文

" --------------------------------- C IMPORTED PACKAGES ------------------------------------------ "
from __future__ import division

import cython
cimport cython

import math

cimport numpy as np
import numpy as np

import pyfftw
" ------------------------------------------------------------------------------------------------ "

"""
I heard that the boundscheck and wraparound functions could improve the performance, but I didn't 
notice any performance gain whatsoever.
"""
@cython.profile(False)
@cython.boundscheck(False)
@cython.wraparound(False)
def fft(np.ndarray[double complex, ndim=2] A, float r, float t, float T, 

           np.ndarray[double complex, ndim=2] C, np.ndarray[double complex, ndim=2] D, 

                np.ndarray[double, ndim=2] V, float alphaC, np.ndarray[double complex, ndim=1] vJT, 

                    np.ndarray[double, ndim=1] k, float b, 

                        np.ndarray[long, ndim=1] strikes, float eta,

                                    np.ndarray[long, ndim=1] weightsT):

    cdef int M = V.shape[0]
    cdef int N = V.shape[1]
    cdef int NFFT = D.shape[1]
    cdef np.ndarray[double complex, ndim=1] F_cT
    cdef np.ndarray[double complex, ndim=2] toFFT = np.empty( (N, NFFT), dtype=complex)
    cdef np.ndarray[double, ndim=2] prices
    cdef float x_1, x_2, position
    cdef np.ndarray[double, ndim=1] y_1
    cdef np.ndarray[double, ndim=1] y_2 
    cdef np.ndarray[double, ndim=1] dydx
    cdef int m, j, n
    cdef np.ndarray[double, ndim=2] price = np.empty( (M * N, len(strikes)) )
    cdef np.ndarray[double complex, ndim=1] A_inp = A[0, :]

    for j in range( len(strikes) ):

        position = (math.log(strikes[j]) + b) / ( 2 * b / NFFT)

        x_1 = math.exp ( k[ int(math.floor(position)) ] )
        x_2 = math.exp ( k[ int(math.ceil(position)) ] )

        for m in range(M):

            """
            M is the number of rows we have in A, C, D and V, so we need to loop over all of those.
            """

            for n in range(N):

                """
                Next we loop over all of the elements for each row in V, corresponding to N. For
                us this corresponds to 1000 (if you haven't changed to N in the main program).

                Each of the rows of A, C and D are 2048 in length. So I tried to loop over all of 
                those as well as for each n, but this made the code 4 times slower.
                """

                F_cT = math.exp( -r*(T - t) ) * np.exp (A_inp + C[m, :] + D[m, :] * V[m, n]) / \
                       ( (alphaC + vJT) * (alphaC + 1 + vJT) )

                toFFT[n, :] = np.exp (b * vJT) * F_cT * eta / 3 * weightsT

            """
            I'm guessing FFT'ing is rather slow using NumPy in Cython?
            """

            prices = np.exp ( -alphaC * k ) / math.pi * np.real ( np.fft.fft(toFFT) )

            y_1 = prices[ :, int(math.floor(position)) ]
            y_2 = prices[ :, int(math.ceil(position)) ]

            dydx = (y_2 - y_1)/(x_2 - x_1)  

            price[m * N:(m + 1) * N, j] = dydx * (strikes[j] - x_1) + y_1

    return price

我将代码编译为

from distutils.core import setup, Extension
from Cython.Distutils import build_ext
import numpy.distutils.misc_util

include_dirs = numpy.distutils.misc_util.get_numpy_include_dirs()


setup(
  name = 'fftInC',
  ext_modules = [Extension('fftInC', ['fftInC.pyx'], include_dirs=include_dirs)],
  cmdclass = {'build_ext':build_ext}
)

但令我惊讶的是,Cython版本比原版慢约3倍。而且我无法弄清楚我哪里出错了。我认为我已经正确定义了输入类型(我理解这应该会给性能带来相当大的提升)。

因此,我的问题是:你能确定我出错的地方吗?是类型定义,for循环还是FFT(或其他)?

0 个答案:

没有答案