Speeding up Cython code

Date: 2015-02-02 10:51:23

Tags: python c performance cython

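I wrote Python code that handles a large amount of data, so it takes a long time to run. Then I discovered Cython and started converting my code.

Basically, all I did was change the function declarations (cdef type name(arguments with their types)), declare cdef variables with types, and declare cdef classes. I wrote all the .pyx files in Eclipse, compiled them with the command python setup.py build_ext --inplace, and ran the result from Eclipse.

My problem is that when I compare the Python and Cython versions, there is no difference in speed.

I ran the command cython -a <file> to generate an HTML report, and it shows a lot of yellow lines.

I don't know what I'm doing wrong, whether I should be including something else, or how to get rid of these yellow lines.

The code is long, so I'm only pasting the part I want to speed up.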


main.pyx

'''there are a lot of ndarray objects stored in a file and in this step I get each of them until there are no more items '''
cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    cdef int runReadWavePoints

    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    runReadWavePoints = 1

    while runReadWavePoints == 1:
        try:
            wavePointManagement.LoadWavePointFile()
            wavePointManagement.RoundCoordinates()
            wavePointManagement.SortWavePointList()
            GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
        except EOFError:  # pickle.load raises EOFError when no arrays are left
            wavePointManagement.CloseWavePointFile()
            columnManagement.CloseWriteColumnFile()
            break

'''I check which points are in the same XYZ (voxel) and in the same XY (column)'''

cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
    cdef int indexWavePointRef, indexWavePoint
    cdef int saved
    cdef double voxelValue
    cdef int sizeWavePointList

    sizeWavePointList = len(wavePointList)

    indexWavePointRef = 0

    while indexWavePointRef < sizeWavePointList - 1:
        saved = 0
        voxelValue = (wavePointList[indexWavePointRef]).GetValue()
        for indexWavePoint in xrange(indexWavePointRef + 1, len(wavePointList)):
            if (wavePointList[indexWavePointRef]).GetX() == (wavePointList[indexWavePoint]).GetX() and (wavePointList[indexWavePointRef]).GetY() == (wavePointList[indexWavePoint]).GetY():
                if (wavePointList[indexWavePointRef]).GetZ() == (wavePointList[indexWavePoint]).GetZ():
                    if voxelValue < (wavePointList[indexWavePoint]).GetValue():
                        voxelValue = (wavePointList[indexWavePoint]).GetValue()
                else:
                    saved = 1
                    CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
                    indexWavePointRef = indexWavePoint
                    if indexWavePointRef == sizeWavePointList - 1:
                        CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), (wavePointList[indexWavePointRef]).GetValue())
                    break
            else:
                saved = 1
                CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
                columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
                columnManagement.AddColumn(columnObject)
                MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ()) 
                indexWavePointRef = indexWavePoint
                break
        if saved == 0:
            CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
            indexWavePointRef = indexWavePoint
    columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
    columnManagement.AddColumn(columnObject)
    MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())



'''I check if the data stored in a voxel is lower than the new one; if so, I store it'''

cdef CheckVoxel (double X, double Y, double Z, double newValue):
    cdef object bandVoxel, structvalCheckVoxel, out_str
    cdef tuple valueCheckVoxel

    bandVoxel = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)
    structvalCheckVoxel = bandVoxel.ReadRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, buf_type=gdal.GDT_Float32)
    valueCheckVoxel = struct.unpack('f', structvalCheckVoxel)

    if newValue > valueCheckVoxel[0]:
        out_str = struct.pack('f', newValue)
        bandVoxel.WriteRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, out_str)

'''I check if this point has the highest Z and I store this information'''    
cdef MaximumHeightColumn(double X, double Y, double newZ):
    cdef object bandMetricMaximumHeightColumn, structvalMaximumHeightColumn, out_strMaximumHeightColumn
    cdef tuple valueMaximumHeightColumn

    bandMetricMaximumHeightColumn = datasetMetrics.GetRasterBand(10)
    structvalMaximumHeightColumn = bandMetricMaximumHeightColumn.ReadRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, buf_type=gdal.GDT_Float32)
    valueMaximumHeightColumn = struct.unpack('f', structvalMaximumHeightColumn)

    if newZ > round(valueMaximumHeightColumn[0], 1):
        out_strMaximumHeightColumn = struct.pack('f', newZ)
        bandMetricMaximumHeightColumn.WriteRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, out_strMaximumHeightColumn)
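GroupColumnsVoxels walks the sorted points one Python object at a time. Assuming each batch were stored as a structured array with hypothetical fields x, y, z and value (the field names and the helper functions below are illustrations, not part of the question's code), what the loop appears to compute can be sketched with whole-array NumPy operations: the largest value per (X, Y, Z) voxel, the largest Z per (X, Y) column, and the band/column/row indices that CheckVoxel derives from Xmin, Ymax and the 0.25/0.3 cell sizes.

```python
import numpy as np

# Hypothetical structured layout for one batch of rounded wave points.
POINT_DTYPE = np.dtype([('x', np.float64), ('y', np.float64),
                        ('z', np.float64), ('value', np.float64)])


def group_starts(*keys):
    """Indices at which any of the (already sorted) key arrays changes."""
    n = len(keys[0])
    changed = np.zeros(n, dtype=bool)
    if n:
        changed[0] = True
    for k in keys:
        changed[1:] |= k[1:] != k[:-1]
    return np.flatnonzero(changed)


def voxels_and_columns(points):
    """Per-batch maxima that the loop hands to CheckVoxel and
    MaximumHeightColumn one point at a time.

    voxels : x, y, z and the largest 'value' of each distinct (x, y, z)
    columns: x, y and the largest z of each distinct (x, y)
    """
    pts = np.sort(points, order=['x', 'y', 'z'])
    x, y, z, v = pts['x'], pts['y'], pts['z'], pts['value']
    if len(pts) == 0:
        return (x, y, z, v), (x, y, z)
    vs = group_starts(x, y, z)
    cs = group_starts(x, y)
    voxels = (x[vs], y[vs], z[vs], np.maximum.reduceat(v, vs))
    columns = (x[cs], y[cs], np.maximum.reduceat(z, cs))
    return voxels, columns


def raster_indices(x, y, z, xmin, ymax, cell=0.25, layer=0.3):
    """Vectorised form of the index arithmetic in CheckVoxel:
    1-based band number, pixel column and pixel row."""
    band = np.floor(np.asarray(z) / layer).astype(np.int64) + 1
    col = np.floor((np.asarray(x) - xmin) / cell).astype(np.int64)
    row = np.floor((ymax - np.asarray(y)) / cell).astype(np.int64)
    return band, col, row


# Usage with made-up points:
pts = np.array([(0.0, 0.0, 0.3, 5.0),
                (0.0, 0.0, 0.3, 7.0),
                (0.0, 0.0, 0.6, 2.0),
                (0.25, 0.0, 0.3, 1.0)], dtype=POINT_DTYPE)
voxels, columns = voxels_and_columns(pts)
band, col, row = raster_indices(*voxels[:3], xmin=0.0, ymax=0.0)
```

The sketch only covers a single batch; it does not reproduce the comparison against values already stored in the rasters, which CheckVoxel and MaximumHeightColumn would still have to perform.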

WavePointManagement.pyx

'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import numpy as np
cimport numpy as np
import math

cdef class WavePointManagement(object):
    '''
    This class manages all the points extracted from the waveform
    '''
    cdef object fileObject, wavePointList
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''
        Constructor
        '''

        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        self.fileObject = file(fileName, 'rb')

    cdef void LoadWavePointFile (self) except *:  # let EOFError reach the caller
        self.wavePointList = None
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

    cdef void RoundCoordinates (self):
        cdef int indexPointObject, sizeWavePointList

        for pointObject in self.GetWavePointList():
            pointObject.SetX(round(math.floor(pointObject.GetX()/0.25)*0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY()/0.25)*0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ()/0.3)*0.3, 1))

    cdef void CloseWavePointFile(self):
        self.fileObject.close()
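LoadWavePointFile reads one pickled array per call, and ReadWavePoints keeps calling it until pickle.load raises EOFError at the end of the file. The same read-until-EOF pattern, sketched in plain Python 3 as a generator (iter_pickled is a hypothetical name, and an in-memory io.BytesIO stands in for the wave-point file):

```python
import io
import pickle


def iter_pickled(file_obj):
    """Yield objects that were pickled one after another into file_obj.

    pickle.load signals the end of the stream by raising EOFError,
    which is caught here so iteration simply stops.
    """
    while True:
        try:
            yield pickle.load(file_obj)
        except EOFError:
            return


# Usage: write three objects back to back, then read them all back.
buf = io.BytesIO()
for chunk in ([1, 2], [3], [4, 5, 6]):
    pickle.dump(chunk, buf)
buf.seek(0)
chunks = list(iter_pickled(buf))
```

Catching only EOFError means that a genuine error raised while processing a batch is reported instead of silently ending the read. Also note that in Cython versions before 3.0, an exception raised inside a cdef void method declared without an except clause is printed and ignored rather than passed on to the caller.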

setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

import numpy

ext = Extension("main", ["main.pyx"], include_dirs = [numpy.get_include()])

setup (ext_modules=[ext], 
       cmdclass = {'build_ext' : build_ext}
       )

test_cython.py

'''this is the file I run with eclipse after compiling'''
from main import main

main()

How can I speed up this code?

1 Answer:

Answer 0 (score: 3)

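Your code jumps back and forth between numpy arrays and Python lists, so the code Cython generates is essentially no different from what plain Python would run.

The following line produces a Python list, and its key function is also a pure Python function.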

self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))
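The list conversion is easy to confirm in plain Python: sorted always builds a new Python list, even when it is given an ndarray, whereas np.sort returns an ndarray.

```python
import numpy as np

arr = np.array([3.0, 1.0, 2.0])

as_list = sorted(arr)    # a Python list of NumPy scalar objects
as_array = np.sort(arr)  # a new ndarray backed by a contiguous float64 buffer

print(type(as_list).__name__, type(as_array).__name__)  # list ndarray
```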

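Instead, you need to use ndarray.sort (or numpy.sort if you don't want to sort in place). To do that, you also need to change how your objects are stored in the array: you need to use a structured array. For examples of sorting a structured array, see the numpy.sort documentation, in particular the last two examples on that page.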
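A minimal pure-Python sketch of that approach, assuming the point objects are replaced by a structured array with hypothetical float64 fields x, y and z:

```python
import numpy as np

# Hypothetical structured dtype replacing the per-point Python objects.
POINT_DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

points = np.array([(1.0, 2.0, 0.3),
                   (0.5, 2.0, 0.6),
                   (0.5, 1.0, 0.9),
                   (0.5, 1.0, 0.3)], dtype=POINT_DTYPE)

# In-place sort by x, then y, then z: the NumPy counterpart of
# sorted(points, key=lambda k: (k.x, k.y, k.z)).
# np.sort(points, order=[...]) would return a sorted copy instead.
points.sort(order=['x', 'y', 'z'])
```

The result stays an ndarray with a fixed memory layout, which is what lets Cython access it without going through Python objects.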

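Once your data is stored in a numpy array, you need to tell Cython how the data is laid out in the array. This includes providing the element type and the number of dimensions of the array. This page provides more information on how to use numpy arrays efficiently.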
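Describing a structured array to Cython with a packed struct works because every element of such an array is a fixed-size block of raw bytes with fixed field offsets, just like a C struct. A pure-Python way to see this, with ctypes standing in for the C struct (the ctypes class is only an illustration; Cython does not use it):

```python
import ctypes

import numpy as np

POINT_DTYPE = np.dtype([('x', np.float64), ('y', np.float64), ('z', np.float64)])


class Point3D(ctypes.Structure):
    # Three doubles need no padding, so this matches a packed C struct.
    _fields_ = [('x', ctypes.c_double),
                ('y', ctypes.c_double),
                ('z', ctypes.c_double)]


# Byte offset of each field inside one record.
np_offsets = [POINT_DTYPE.fields[name][1] for name in POINT_DTYPE.names]
c_offsets = [getattr(Point3D, name).offset for name in ('x', 'y', 'z')]
print(POINT_DTYPE.itemsize, ctypes.sizeof(Point3D))  # 24 24
print(np_offsets, c_offsets)                          # [0, 8, 16] [0, 8, 16]

# The raw bytes of the C struct can be reinterpreted as one array record.
record = np.frombuffer(bytes(Point3D(1.0, 2.0, 3.0)), dtype=POINT_DTYPE)[0]
print(record['y'])  # 2.0
```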

An example showing how to create and sort a structured array:

import numpy as np
cimport numpy as np

DTYPE = [('name', 'S10'), ('height', np.float64), ('age', np.int32)]

cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t

def create_array():
    values = [('Arthur', 1.8, 41), ('Lancelot', 1.9, 38),
              ('Galahad', 1.7, 38)]
    return np.array(values, dtype=DTYPE)

cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    arr.sort(order=['age', 'height'])  

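Finally, to speed things up further you need to convert your code from calling Python functions to calling standard C library functions. Below is an example based on RoundCoordinates. `cpdef` means the function is also exposed to Python through a wrapper function.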

cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round

import numpy as np

DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t

# Caution should be used when turning the bounds check off as it can lead to undefined 
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    cdef int i
    cdef DTYPE_t point
    for i in range(len(pointlist)): # this line is optimised into a c loop
        point = pointlist[i] # creates a copy of the point
        point.x = round(floor(point.x/0.25)*25) / 100  # 2 decimals, like round(..., 2)
        point.y = round(ceil(point.y/0.25)*25) / 100
        point.z = round(floor(point.z/0.3)*3) / 10
        pointlist[i] = point # overwrites the old point data with the new data
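If the points are already in a structured array, the same snapping can also be written with whole-array NumPy operations instead of an explicit loop. This vectorized form is a swapped-in alternative, not the Cython approach shown above; it assumes the x/y/z field names used in the example and follows the original Python RoundCoordinates (2 decimals for x and y, 1 for z). The function name is hypothetical.

```python
import numpy as np

DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]


def round_coordinates_np(points):
    """Snap x/y to a 0.25 grid and z to a 0.3 grid, modifying points in place."""
    points['x'] = np.round(np.floor(points['x'] / 0.25) * 0.25, 2)
    points['y'] = np.round(np.ceil(points['y'] / 0.25) * 0.25, 2)
    points['z'] = np.round(np.floor(points['z'] / 0.3) * 0.3, 1)


pts = np.array([(1.3, 2.1, 0.7), (0.1, 0.0, 1.0)], dtype=DTYPE)
round_coordinates_np(pts)
```

Each line processes the whole field at once in compiled NumPy code, so no Python-level loop remains.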

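Finally, before rewriting your whole codebase, you should profile your code to find the functions in which the program spends most of its time, and optimize those before moving on to the others.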
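Python's standard library profiler is one way to do this. A minimal sketch, where slow_part, fast_part and run are hypothetical stand-ins for the real program:

```python
import cProfile
import io
import pstats


def slow_part(n):
    # Stand-in for an expensive step such as GroupColumnsVoxels.
    return sum(i * i for i in range(n))


def fast_part(n):
    return n * 2


def run():
    for _ in range(20):
        slow_part(20000)
        fast_part(10)


profiler = cProfile.Profile()
profiler.enable()
run()
profiler.disable()

out = io.StringIO()
stats = pstats.Stats(profiler, stream=out).sort_stats('cumulative')
stats.print_stats(5)  # the 5 entries with the largest cumulative time
report = out.getvalue()
print(report)
```

Functions compiled by Cython only show up in such a report if the module is built with the profile=True compiler directive (for example a `# cython: profile=True` comment at the top of the .pyx file).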