如何正确地将numpy逻辑函数传递给Cython?

时间:2018-02-17 20:31:07

标签: python numpy optimization cython

我应该在逻辑函数/索引操作中加入哪些声明,以便Cython完成繁重的工作?

我有两个相同大小的numpy数组形式的大型栅格。第一个数组包含植被索引值,第二个数组包含字段ID。目标是按字段平均植被指数值。两个数组都有令人讨厌的nodata值(-9999),我想忽略它们。

目前该功能需要60秒才能执行,这通常我不介意,但我会处理数百张图像。即使是30秒的改进也很重要。所以我一直在探索Cython作为一种帮助加快速度的方法。我一直在使用Cython numpy tutorial作为指南。

Example data

test_cy.pyx代码:

import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function 

cpdef test():
  cdef np.ndarray[np.int16_t, ndim=2] ndvi_array = np.load("Z:cython_test/data/ndvi.npy")

  cdef np.ndarray[np.int16_t, ndim=2] field_array = np.load("Z:cython_test/data/field_array.npy")

  cdef np.ndarray[np.int16_t, ndim=1] unique_field = np.unique(field_array)
  unique_field = unique_field[unique_field != -9999]

  cdef int field_id
  cdef np.ndarray[np.int16_t, ndim=1] f_ndvi_values
  cdef double f_avg

  for field_id in unique_field :
      f_ndvi_values = ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]
      f_avg = np.mean(f_ndvi_values)

Setup.py代码:

try:
    from setuptools import setup
    from setuptools import Extension
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension

from Cython.Build import cythonize
import numpy

setup(ext_modules = cythonize('test_cy.pyx'),
      include_dirs=[numpy.get_include()])

经过一番研究和运行:

cython -a test_cy.pyx

似乎索引操作ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]是瓶颈,仍然依赖于Python。我怀疑我在这里遗漏了一些重要的声明。包括ndim没有任何影响。

我对numpy也很新,所以我可能会遗漏一些明显的东西。

1 个答案:

答案 0 :(得分:1)

你的问题对我来说看起来很容易上传,所以Cython可能不是最好的方法。 (当存在不可避免的细粒度循环时,Cython会发光。)由于您的dtype为int16,因此只有有限范围的可能标签,因此使用np.bincount应该相当有效。尝试类似的事情(这假设所有有效值都是> = 0,如果情况不是你必须转移 - 或者(更便宜)视图转换为uint16(因为我们没有做任何算术在应该是安全的标签上 - 在使用bincount之前):

mask = (ndvi_array != -9999) & (field_array != -9999)
nd = ndvi_array[mask]
fi = field_array[mask]
counts = np.bincount(fi, minlength=2**15)
sums = np.bincount(fi, nd, minlength=2**15)
valid = counts != 0
avgs = sums[valid] / counts[valid]