从Cython中的列表中调用方法

时间:2017-11-12 03:07:32

标签: python parallel-processing cython cythonize

我想并行执行3个函数,在Cython的prange循环中使用相同的输入。它们在相同的变量TVdu上添加值,采用相同的变量。代码的目的是计算四个主方向上的像素梯度,然后按像素计算总变差。

为此,我使用方法名称创建一个列表并迭代此列表。我有这段代码:

cdef void TV_norm(float[:, :] ux, float[:, :] uy, float[:, :] output, float epsilon, float p) nogil:
    cdef int M = ux.shape[0]
    cdef int N = ux.shape[1]
    cdef int i, j
    cdef float inv_p = 1./p
    cdef float eps = epsilon**p

    with parallel(num_threads=64):
        for i in prange(M, schedule="guided"):
            for j in range(N):
                output[i, j] += (abs(ux[i, j])** p + abs(uy[i, j])** p + eps) **inv_p

cdef void center_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
    ux = np.roll(u, (di, 0)) - u
    uy = np.roll(u, (0, dj)) - u
    TV_norm(ux, uy, TV, epsilon, p)
    du -= ux + uy


cdef void i_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
    ux = u - np.roll(u, (-di, 0))
    uy = np.roll(u, (-di, dj)) - np.roll(u, (-di, 0))
    TV_norm(ux, uy, TV, epsilon, p)
    du += ux


cdef void j_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
    ux = np.roll(u, (di, -dj)) - np.roll(u, (0, -dj))
    uy = u - np.roll(u, (0, -dj))
    TV_norm(ux, uy, TV, epsilon, p)
    du += uy


cdef list divTV_dual(float[:, :] u, float epsilon=0, float p=1):
    cdef np.ndarray[DTYPE_t, ndim=2] TV = np.zeros_like(u)
    cdef np.ndarray[DTYPE_t, ndim=2] du = TV.copy()
    cdef list shifts = [[1, 1],[-1, 1],[1,-1],[-1, -1]]
    cdef list methods = [center_diff, i_diff, j_diff]

    with nogil, parallel(num_threads=4):
        for i in prange(4, schedule="static"):
            with gil:
                di = shifts[i][0]
                dj = shifts[i][1]
                for j in range(3):
                    methods[j](u, TV, du, di, dj, epsilon, p)

    return [du, TV]

虽然它在纯Python中有效,但Cython在编译时失败了:

/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
   2129             magic_arg_s = self.var_expand(line, stack_depth)
   2130             with self.builtin_trap:
-> 2131                 result = fn(magic_arg_s, cell)
   2132             return result
   2133 

<decorator-gen-127> in cython(self, line, cell)

/usr/local/lib/python3.5/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
    185     # but it's overkill for just that one bit of state.
    186     def magic_deco(arg):
--> 187         call = lambda f, *a, **k: f(*a, **k)
    188 
    189         if callable(arg):

/usr/local/lib/python3.5/dist-packages/Cython/Build/IpythonMagic.py in cython(self, line, cell)
    289             build_extension.build_temp = os.path.dirname(pyx_file)
    290             build_extension.build_lib  = lib_dir
--> 291             build_extension.run()
    292             self._code_cache[key] = module_name
    293 

/usr/lib/python3.5/distutils/command/build_ext.py in run(self)
    336 
    337         # Now actually compile and link everything.
--> 338         self.build_extensions()
    339 
    340     def check_extensions_list(self, extensions):

/usr/lib/python3.5/distutils/command/build_ext.py in build_extensions(self)
    445             self._build_extensions_parallel()
    446         else:
--> 447             self._build_extensions_serial()
    448 
    449     def _build_extensions_parallel(self):

/usr/lib/python3.5/distutils/command/build_ext.py in _build_extensions_serial(self)
    470         for ext in self.extensions:
    471             with self._filter_build_errors(ext):
--> 472                 self.build_extension(ext)
    473 
    474     @contextlib.contextmanager

/usr/lib/python3.5/distutils/command/build_ext.py in build_extension(self, ext)
    530                                          debug=self.debug,
    531                                          extra_postargs=extra_args,
--> 532                                          depends=ext.depends)
    533 
    534         # XXX outdated variable, kept here in case third-part code

/usr/lib/python3.5/distutils/ccompiler.py in compile(self, sources, output_dir, macros, include_dirs, debug, extra_preargs, extra_postargs, depends)
    572             except KeyError:
    573                 continue
--> 574             self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
    575 
    576         # Return *all* object filenames, not just the ones we just built.

/usr/lib/python3.5/distutils/unixccompiler.py in _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts)
    118                        extra_postargs)
    119         except DistutilsExecError as msg:
--> 120             raise CompileError(msg)
    121 
    122     def create_static_lib(self, objects, output_libname,

CompileError: command 'x86_64-linux-gnu-gcc' failed with exit status 1

有什么意思吗?

编辑:

这个概念证明有效:

%%cython --compile-args=-O3 --compile-args=-ffast-math --compile-args=-fopenmp --link-args=-fopenmp

# cython: boundscheck=False
# cython: cdivision=True
# cython: wraparound=False
# cython: profile=True

cimport cython
from cython.parallel cimport parallel, prange

cdef foo(a):
    print(a)

cdef bar(a):
    print(a)

methods = [foo, bar]
cdef int i

with nogil, parallel():
    for i in prange(2):
        with gil:
            methods[i]("a")

1 个答案:

答案 0 :(得分:2)

找到它...列表中调用的方法应该使用cpdef而不是cdef来定义。

我想这是因为这些函数使用了numpy类型和方法,所以它们需要被python暴露。