我必须使用python以数字方式求解微分方程。基本上我有两个不同的代码。一个负责阅读问题的初始条件,一个负责所有帐户。我想用cython优化第二个。 当我将常量的静态类型(dz,dt,i,k,j ..)定义为浮点数或整数时,我减少了四分之一的计算时间。现在,当我为numpy数组定义静态类型时,我没有任何改进。
这是我的代码(.pyx):
import numpy as np
cimport numpy as np
DTYPE = np.int
ctypedef np.int_t DTYPE_t
def explicit_cython(np.ndarray u, float kappa, float dt, float dz, np.ndarray term_const, unsigned int nz, plot_time):
'''Cython version of explicit method'''
#Defining C types
cdef unsigned int i, k, j
cdef unsigned int len_plot = len(plot_time) - 1
cdef float lamnda = kappa*dt/dz**2
u_out = []
u_out.append(u.copy())
for i in range(len_plot):
for k in range(plot_time[i], plot_time[i+1]):
un = u.copy()
for j in range(1, nz-1):
u[j] = un[j] + lamnda*(un[j+1] - 2*un[j] + un[j-1]) + term_const[j]
u_out.append(u.copy())
return u_out
这是我用来编译的设置。
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
extensions=[Extension("explicit_cython2",["explicit_cython2.pyx"])]
setup(
ext_modules = cythonize(extensions)
)
当我python3 setup.py build_ext --inplace
时,请发出此警告:
In file included from /usr/include/numpy/ndarraytypes.h:1728:0,
from /usr/include/numpy/ndarrayobject.h:17,
from /usr/include/numpy/arrayobject.h:15,
from explicit_cython2.c:258:
/usr/include/numpy/npy_deprecated_api.h:11:2: warning: #warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it by #defining NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION"
为什么没有通过定义静态类型的numpy来赢得速度?为什么我有这个警告? THX!
PD。我在LMDE中使用python 3.4和Anaconda
答案 0 :(得分:2)
A)除非你能定义你的numpy数组的维度和interntal数据类型,否则你可能不会得到任何好处
def explicit_cython(np.ndarray[np.float_t,ndim=2],...
B)我认为弃用的警告是说新的更好的界面是类型化的内存视图http://docs.cython.org/src/userguide/memoryviews.html。如果你不想要那些,那就忽略它。
C)你可能会失去很多速度复制的东西,你会立即覆盖每一步,如果你可以做np.zeros(n.shape)
而你可能会获得一点点。 (或者甚至只是跳过内部for k
循环中的副本)。
D)循环的主要内容可以进行矢量化,无论如何都要避免使用Cython。