我编写了一段处理大量数据的 Python 代码，运行需要花费很长时间。于是我找到了 Cython，并开始修改我的代码。
基本上，我所做的只是修改函数的声明（`cdef 返回类型 名称(带类型的参数)`）、用 `cdef` 声明带类型的变量，以及声明 `cdef` 类。
我用 Eclipse 编写了所有 .pyx 文件，然后用命令 `python setup.py build_ext --inplace` 编译，并在 Eclipse 中运行。
我的问题是：比较 Python 版本和 Cython 版本的运行速度，两者没有任何区别。
我运行命令 `cython -a <file>` 生成了一个 HTML 文件，里面有很多黄色的行。
我不知道自己哪里做错了、是否还需要包含其他内容，也不知道如何消除这些黄色的行。
由于代码很长，我只贴出其中一部分，也就是我想要加速的部分。
cdef ReadWavePoints(WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    '''Load every pickled point list from ``wavePointsFile`` and group it.

    The file (module-level name ``wavePointsFile``) holds a sequence of pickled
    objects. Each one is loaded, rounded, sorted and passed to
    GroupColumnsVoxels until the file has no more items. The wave point file and
    the column file are closed when processing stops, whether it stops because
    the file is exhausted or because an error propagates to the caller.
    '''
    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    try:
        while True:
            try:
                wavePointManagement.LoadWavePointFile()
            except EOFError:
                # pickle.load raises EOFError once every object has been read.
                break
            if wavePointManagement.GetWavePointList() is None:
                # LoadWavePointFile resets the list to None before loading. With
                # Cython < 3, an exception raised inside a `cdef void` method is
                # printed and discarded instead of propagated, so a failed load
                # (including end of file) is only visible as the list still
                # being None.
                break
            wavePointManagement.RoundCoordinates()
            wavePointManagement.SortWavePointList()
            GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
    finally:
        # Errors from rounding/sorting/grouping are no longer silently treated as
        # "end of file"; they propagate after both files are closed.
        wavePointManagement.CloseWavePointFile()
        columnManagement.CloseWriteColumnFile()
cdef _FlushColumn(double colX, double colY, double topZ, double voxelValue,
                  ColumnManagement columnManagement):
    '''Store the last voxel of a finished column, then register the column.'''
    CheckVoxel(colX, colY, topZ, voxelValue)
    columnManagement.AddColumn(columnInstance.Column(colX, colY))
    MaximumHeightColumn(colX, colY, topZ)


cdef GroupColumnsVoxels(object wavePointList, ColumnManagement columnManagement):
    '''Group points that share a voxel (X, Y, Z) and a column (X, Y).

    wavePointList must be sorted by (x, y, z), as SortWavePointList does, so
    that the points of one column, and the points of one voxel inside it, are
    contiguous.

    For every voxel, CheckVoxel is called once with the largest GetValue() of
    its points. For every column, a columnInstance.Column is added to
    columnManagement, and MaximumHeightColumn is called with the Z of the
    column's last voxel in sorted order (its highest Z when sorted ascending).
    Every point is processed, including the final point and a column that runs
    to the end of the list.
    '''
    cdef Py_ssize_t index, size
    cdef double x, y, z, value
    cdef double colX = 0.0, colY = 0.0, voxelZ = 0.0, voxelValue = 0.0
    cdef bint inColumn = False

    size = len(wavePointList)
    for index in range(size):
        point = wavePointList[index]
        x = point.GetX()
        y = point.GetY()
        z = point.GetZ()
        value = point.GetValue()

        if inColumn and x == colX and y == colY:
            if z == voxelZ:
                # Same voxel: keep the maximum value seen so far.
                if value > voxelValue:
                    voxelValue = value
                continue
            # Same column, next voxel: store the voxel that just ended.
            CheckVoxel(colX, colY, voxelZ, voxelValue)
        else:
            # New column: finish the previous column (if any) first.
            if inColumn:
                _FlushColumn(colX, colY, voxelZ, voxelValue, columnManagement)
            colX = x
            colY = y
            inColumn = True

        # This point opens a new voxel; its value is the voxel's starting max.
        voxelZ = z
        voxelValue = value

    if inColumn:
        _FlushColumn(colX, colY, voxelZ, voxelValue, columnManagement)
cdef CheckVoxel(double X, double Y, double Z, double newValue):
    '''Write newValue into the voxel containing (X, Y, Z) if it exceeds the stored value.

    The band of the module-level ``datasetVoxels`` raster is
    floor(Z / 0.3) + 1. The pixel is the 0.25-unit cell counted from the
    module-level ``Xmin`` (columns) and ``Ymax`` (rows). The stored value is
    read and written as a single Float32.
    '''
    cdef int col, row
    cdef object band
    cdef tuple stored

    # Compute the pixel once so the read and the write address the same cell.
    col = int(math.floor((X - Xmin) / 0.25))
    row = int(math.floor((Ymax - Y) / 0.25))
    band = datasetVoxels.GetRasterBand(int(math.floor(Z / 0.3)) + 1)

    stored = struct.unpack('f', band.ReadRaster(col, row, 1, 1, buf_type=gdal.GDT_Float32))
    if newValue > stored[0]:
        band.WriteRaster(col, row, 1, 1, struct.pack('f', newValue))
cdef MaximumHeightColumn(double X, double Y, double newZ):
    '''Store newZ as the column's maximum height if it exceeds the stored value.

    Uses band 10 of the module-level ``datasetMetrics`` raster, at the
    0.25-unit cell counted from ``Xmin`` (columns) and ``Ymax`` (rows). The
    stored value is read and written as a single Float32.
    '''
    cdef int col, row
    cdef object band
    cdef tuple stored

    band = datasetMetrics.GetRasterBand(10)
    # Compute the pixel once so the read and the write address the same cell.
    col = int(math.floor((X - Xmin) / 0.25))
    row = int(math.floor((Ymax - Y) / 0.25))

    stored = struct.unpack('f', band.ReadRaster(col, row, 1, 1, buf_type=gdal.GDT_Float32))
    # The Float32 read-back is not exact, so compare at one decimal; presumably
    # chosen to match the 1-decimal rounding RoundCoordinates applies to Z.
    if newZ > round(stored[0], 1):
        band.WriteRaster(col, row, 1, 1, struct.pack('f', newZ))
'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import math
from operator import attrgetter

import numpy as np
cimport numpy as np
cdef class WavePointManagement(object):
    '''
    Load, round and sort the point lists stored (pickled) in a file.

    Each pickled object is iterated as a collection of point objects that expose
    x/y/z attributes (used as the sort key) and GetX/SetX, GetY/SetY, GetZ/SetZ
    accessors; the point class itself is defined elsewhere.
    '''
    # Extension-type attributes. A cdef class has no per-instance __dict__, so
    # the former __slots__ declaration had no effect and was removed.
    cdef object fileObject, wavePointList

    def __cinit__(self):
        '''Start with no open file and an empty point list.'''
        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        '''Return the most recently loaded (and possibly rounded/sorted) points.'''
        return self.wavePointList

    # The `cdef void` methods below are declared `except *`. Without it,
    # Cython < 3 prints and discards any exception they raise (for example the
    # EOFError that marks the end of the pickle file) instead of passing it on
    # to the caller.
    cdef void OpenWavePointFileLoad(self, object fileName) except *:
        '''Open fileName for binary reading; the handle is kept for later loads.'''
        # open() instead of the Python-2-only file() builtin.
        self.fileObject = open(fileName, 'rb')

    cdef void LoadWavePointFile(self) except *:
        '''Unpickle the next point list from the open file.

        The list is reset to None first, so if pickle.load raises (EOFError at
        the end of the file) wavePointList is left as None instead of keeping
        the previous list.
        NOTE(review): unpickling can execute arbitrary code; only load files
        this program produced.
        '''
        self.wavePointList = None
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList(self) except *:
        '''Replace the point list with a new list sorted by (x, y, z).'''
        # attrgetter builds the same (x, y, z) key tuple as a lambda, in C.
        self.wavePointList = sorted(self.wavePointList, key=attrgetter('x', 'y', 'z'))

    cdef void RoundCoordinates(self) except *:
        '''Snap every point's coordinates in place.

        X is floored and Y ceiled to the 0.25 grid (rounded to 2 decimals);
        Z is floored to the 0.3 grid (rounded to 1 decimal).
        '''
        for pointObject in self.wavePointList:
            pointObject.SetX(round(math.floor(pointObject.GetX() / 0.25) * 0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY() / 0.25) * 0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ() / 0.3) * 0.3, 1))

    cdef void CloseWavePointFile(self) except *:
        '''Close the file opened by OpenWavePointFileLoad; a second call is a no-op.'''
        if self.fileObject is not None:
            self.fileObject.close()
            self.fileObject = None
"""Build the ``main`` Cython extension.

Usage: python setup.py build_ext --inplace
"""
from distutils.core import setup
from distutils.extension import Extension

import numpy
from Cython.Build import cythonize

ext = Extension("main", ["main.pyx"], include_dirs=[numpy.get_include()])

setup(
    # cythonize() is the documented replacement for the legacy
    # Cython.Distutils.build_ext command class.
    ext_modules=cythonize(
        [ext],
        # The .pyx code uses Python 2 constructs (cPickle, xrange, file), so pin
        # the language level instead of relying on the Cython version's default.
        compiler_directives={"language_level": 2},
        # Also emit the HTML annotation report (same as `cython -a`) on each build.
        annotate=True,
    ),
)
'''Script run from Eclipse after the ``main`` extension has been compiled.'''
from main import main

# Guard so that importing this module does not start the whole run.
if __name__ == '__main__':
    main()
我怎样才能加快这段代码的速度?
答案 0 :(得分:3)
您的代码在 numpy 数组和 Python 列表之间来回切换，因此 Cython 生成的代码与纯 Python 代码几乎没有区别。
下面这行代码会生成一个 Python 列表，而且其中的 key 函数也是纯 Python 函数：
self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))
如果要对数组排序，应使用 `ndarray.sort`（或 `numpy.sort`）。为此，您还需要改变对象在数组中的存储方式，即使用结构化数组（structured array）。有关如何对结构化数组排序的示例，请参阅 `numpy.sort` 的文档，特别是页面上的最后两个示例。
数据存入 numpy 数组之后，还需要告诉 Cython 数据在数组中是如何存储的，包括提供元素的类型信息和数组的维数。此页面提供了有关如何高效使用 numpy 数组的更多信息。
下面是创建结构化数组并对其排序的示例：
import numpy as np
cimport numpy as np

# NumPy record layout; it must match the Person struct below field-for-field
# so that Cython's buffer type check accepts arrays created with it.
DTYPE = [('name', 'S10'), ('height', np.float64), ('age', np.int32)]

# C view of one record. `packed` because NumPy structured dtypes contain no
# alignment padding unless they are created with align=True.
cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t


def create_array():
    '''Return a structured array holding three example records.'''
    records = [
        ('Arthur', 1.8, 41),
        ('Lancelot', 1.9, 38),
        ('Galahad', 1.7, 38),
    ]
    result = np.array(records, dtype=DTYPE)
    return result


cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    '''Sort arr in place by age, breaking ties by height.'''
    arr.sort(order=['age', 'height'])
最后，为了进一步提速，您需要把代码中对 Python 函数的调用改为对 C 标准库函数的调用。下面以 `RoundCoordinates` 为例。`cpdef` 表示该函数同时通过一个包装函数暴露给 Python。
cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round
import numpy as np

# NumPy record layout; it must match Point3D field-for-field.
DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t

# Caution should be used when turning the bounds check off as it can lead to undefined
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    '''Snap every point of pointlist in place, matching WavePointManagement.RoundCoordinates.

    x -> round(floor(x / 0.25) * 0.25, 2)
    y -> round(ceil(y / 0.25) * 0.25, 2)
    z -> round(floor(z / 0.3) * 0.3, 1)

    Rounding to d decimals is written as round(k * step * 10**d) / 10**d with
    C's round(). Because k is a whole number, k * 25 and k * 3 are whole
    numbers too, so no halfway cases occur.
    '''
    # Py_ssize_t rather than int: an int index overflows past 2**31 - 1 points.
    cdef Py_ssize_t i
    cdef DTYPE_t point
    for i in range(pointlist.shape[0]):  # compiled to a plain C loop
        point = pointlist[i]  # creates a copy of the point
        # 0.25 * 100 == 25: two decimals, as in round(..., 2). The previous
        # `* 2.5) / 10` kept only one decimal and turned 0.25 into 0.3.
        point.x = round(floor(point.x / 0.25) * 25) / 100
        point.y = round(ceil(point.y / 0.25) * 25) / 100
        # 0.3 * 10 == 3: one decimal, as in round(..., 1).
        point.z = round(floor(point.z / 0.3) * 3) / 10
        pointlist[i] = point  # overwrites the old point data with the new data
另外，在重写整个代码库之前，您应该先对代码进行性能分析（profiling），找出程序耗时最多的函数，优先优化这些函数，然后再考虑其他部分。