我想使用cython加速以下代码:
class A(object):
cdef fun(self):
return 3
class B(object):
cdef fun(self):
return 2
def test():
cdef int x, y, i, s = 0
a = [ [A(), B()], [B(), A()]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
s += a[x][y].fun()
return s
唯一想到的是这样的事情:
def test():
cdef int x, y, i, s = 0
types = [ [0, 1], [1, 0]]
data = [[...], [...]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
if types[x,y] == 0:
s+= A(data[x,y]).fun()
else:
s+= B(data[x,y]).fun()
return s
基本上,C ++中的解决方案是使用虚方法fun()
获得指向某个基类的指针数组,然后您可以很快地迭代它。有没有办法使用python / cython?
答案 0 :(得分:7)
看起来像这样的代码提供了大约20倍的加速:
import numpy as np
cimport numpy as np
cdef class Base(object):
cdef int fun(self):
return -1
cdef class A(Base):
cdef int fun(self):
return 3
cdef class B(Base):
cdef int fun(self):
return 2
def test():
bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_)
cdef np.ndarray[dtype=object, ndim=2] a = bbb
cdef int i, x, y
cdef int s = 0
cdef Base u
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
u = a[x,y]
s += u.fun()
return s
它甚至检查A和B是否继承自Base,可能有一种方法可以在发布版本中禁用它并获得额外的加速
编辑:可以使用
删除检查u = <Base>a[x,y]