我应该在逻辑函数/索引操作中加入哪些声明,以便Cython完成繁重的工作?
我有两个相同大小的numpy数组形式的大型栅格。第一个数组包含植被索引值,第二个数组包含字段ID。目标是按字段平均植被指数值。两个数组都有令人讨厌的nodata值(-9999),我想忽略它们。
目前该功能需要60秒才能执行,这通常我不介意,但我会处理数百张图像。即使是30秒的改进也很重要。所以我一直在探索Cython作为一种帮助加快速度的方法。我一直在使用Cython numpy tutorial作为指南。
test_cy.pyx代码:
import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
cpdef test():
cdef np.ndarray[np.int16_t, ndim=2] ndvi_array = np.load("Z:cython_test/data/ndvi.npy")
cdef np.ndarray[np.int16_t, ndim=2] field_array = np.load("Z:cython_test/data/field_array.npy")
cdef np.ndarray[np.int16_t, ndim=1] unique_field = np.unique(field_array)
unique_field = unique_field[unique_field != -9999]
cdef int field_id
cdef np.ndarray[np.int16_t, ndim=1] f_ndvi_values
cdef double f_avg
for field_id in unique_field :
f_ndvi_values = ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]
f_avg = np.mean(f_ndvi_values)
Setup.py代码:
try:
from setuptools import setup
from setuptools import Extension
except ImportError:
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
import numpy
setup(ext_modules = cythonize('test_cy.pyx'),
include_dirs=[numpy.get_include()])
经过一番研究和运行:
cython -a test_cy.pyx
似乎索引操作ndvi_array[np.logical_and(field_array == field_id, ndvi_array != -9999)]
是瓶颈,仍然依赖于Python。我怀疑我在这里遗漏了一些重要的声明。包括ndim
没有任何影响。
我对numpy也很新,所以我可能会遗漏一些明显的东西。
答案 0 :(得分:1)
你的问题对我来说看起来很容易上传,所以Cython可能不是最好的方法。 (当存在不可避免的细粒度循环时,Cython会发光。)由于您的dtype为int16
,因此只有有限范围的可能标签,因此使用np.bincount
应该相当有效。尝试类似的事情(这假设所有有效值都是> = 0,如果情况不是你必须转移 - 或者(更便宜)视图转换为uint16
(因为我们没有做任何算术在应该是安全的标签上 - 在使用bincount
之前):
mask = (ndvi_array != -9999) & (field_array != -9999)
nd = ndvi_array[mask]
fi = field_array[mask]
counts = np.bincount(fi, minlength=2**15)
sums = np.bincount(fi, nd, minlength=2**15)
valid = counts != 0
avgs = sums[valid] / counts[valid]