我试图通过使用cython来增强python中的一些计算... 在我的计算中,我将进行双循环或更多加上我不能总是使用numpy向量化,所以我需要用cython来增强python循环。
这里我对一些简单的计算进行了基准测试,结果表明,cython比使用numpy慢10倍。我确信numpy已经被优化到最大值,我怀疑我能够击败它的性能,但仍然是因为我做错了。建议?
import numpy as np
from histogram import distances
import time
REPEAT = 10
def printTime(message, t):
print "%s total: %.7f(s) --> average: %.7f(s) %.7f(Ms)"%(message, t, t/REPEAT, 1000000*t/REPEAT)
DATA = np.array( np.random.random((10000, 3)), dtype=np.float32)
POINT = np.array( np.random.random((1,3)), dtype=np.float32)
# numpy histogram
r = REPEAT
startTime = time.clock()
while r:
diff = (DATA-POINT)%1
diffNumpy = np.where(diff<0, diff+1, diff)
distNumpy = np.sqrt( np.add.reduce(diff**2,1) )
r-=1
printTime("numpy", time.clock()-startTime)
# cython test
r = REPEAT
startTime = time.clock()
while r:
distCython = distances(POINT, DATA)
r-=1
printTime("cython", time.clock()-startTime)
import numpy as np
import cython
cimport cython
cimport numpy as np
DTYPE=np.float32
ctypedef np.float32_t DTYPE_C
@cython.nonecheck(False)
@cython.boundscheck(False)
@cython.wraparound(False)
def distances(np.ndarray[DTYPE_C, ndim=2] point, np.ndarray[DTYPE_C, ndim=2] data):
# declare variables
cdef int i
cdef float x,y,z
cdef np.ndarray[DTYPE_C, mode="c", ndim=1] dist = np.empty((data.shape[0]), dtype=DTYPE)
# loop
for i from 0 <= i < data.shape[0]:
# calculate distance
x = (data[i,0]-point[0,0])%1
y = (data[i,1]-point[0,1])%1
z = (data[i,2]-point[0,2])%1
# fold between 0 and 1
if x<0: x+=1
if y<0: y+=1
if z<0: z+=1
# assign to array
dist[i] = np.sqrt(x**2+y**2+z**2)
return dist
from distutils.core import setup
from Cython.Build import cythonize
import numpy as np
setup(
ext_modules = cythonize("histogram.pyx"),
include_dirs=[np.get_include()]
)
编译执行以下操作 python setup.py build_ext --inplace
推出Benchmarch python test.py
numpy total: 0.0153390(s) --> average: 0.0015339(s) 1533.9000000(Ms)
cython total: 0.1509920(s) --> average: 0.0150992(s) 15099.2000000(Ms)
答案 0 :(得分:2)
你的问题几乎肯定是
np.sqrt(x**2+y**2+z**2)
您应该使用C sqrt
功能。它看起来像
from libc.math cimport sqrt
sqrt(x*x + y*y + z*z)