此Cython函数返回在特定限制内的numpy数组元素中的随机元素:
cdef int search(np.ndarray[int] pool):
cdef np.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
这很好用。但是,此功能对于我的代码的性能非常关键。输入的内存视图显然比numpy数组快得多,但是不能以与上述相同的方式对其进行过滤。
如何使用键入的memoryviews编写一个与上述功能相同的函数?还是有另一种方法来提高功能的性能?
答案 0 :(得分:5)
好的,让我们开始使代码更通用,我稍后将介绍性能方面。
我通常不使用:
import numpy as np
cimport numpy as np
我个人喜欢为cimport
ed包使用一个不同的名称,因为它有助于使C端和NumPy-Python端保持分开。所以对于这个答案,我将使用
import numpy as np
cimport numpy as cnp
我还将创建函数的lower_limit
和upper_limit
参数。也许在您的情况下是静态(或全局)定义的,但是这使示例更加独立。因此,起点是您的代码的稍作修改的版本:
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
Cython的一个非常不错的功能是fused types,因此您可以轻松地将此功能概括为不同的类型。您的方法仅适用于32位整数数组(至少在计算机上int
为32位的情况下)。支持更多数组类型非常容易:
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
当然,您可以根据需要添加更多类型。优点是新版本可以在旧版本失败的地方工作:
>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0
现在更笼统,让我们看一下您的函数实际执行的操作:
为什么要创建这么多数组?我的意思是,您可以简单地计算出限制内有多少个元素,取0到限制内的元素数之间的随机整数,然后取结果数组中该索引处的元素将要
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
# Count the number of elements that are within the limits
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
# Take a random index
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
# Go through the array again and take the element at the random index that
# is within the bounds
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
它不会快很多,但是会节省很多内存。而且因为没有中间数组,所以根本不需要内存视图-但是如果愿意,可以将参数列表中的cnp.ndarray[int_or_float] arr
替换为int_or_float[:]
甚至是int_or_float[::1] arr
,在memoryview上运行(它可能不会更快,但也不会很慢)。
相对于Cython,我通常更喜欢numba(至少在我使用它的情况下),所以让我们将其与该代码的numba版本进行比较:
import numba as nb
import numpy as np
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
还有numexpr
的变体:
import numexpr
np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
好的,让我们做一个基准测试
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
因此,通过不使用中间数组,您的速度大约可以提高5倍,而如果使用numba,则可能会增加5倍(似乎我在这里缺少一些可能的Cython优化,numba通常会快2倍左右,或者像Cython一样快)。因此,使用numba解决方案可以使它快20倍左右。
numexpr
在这里并没有真正的可比性,主要是因为您不能在此处使用布尔数组索引。
差异将取决于数组的内容和限制。您还必须衡量应用程序的性能。
顺便说一句:如果下限和上限通常不改变,最快的解决方案是过滤一次数组,然后多次调用np.random.choice
。可能快了个数量级。
lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]
def search_cached():
return np.random.choice(filtered_array)
%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
因此快了将近1000倍,根本不需要Cython或numba。但这是一种特殊情况,可能对您没有用。
如果您想自己动手做基准测试,请在此处(基于Jupyter笔记本/实验室,即%
-符号):
%load_ext cython
%%cython
cimport numpy as cnp
import numpy as np
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
import numexpr
import numba as nb
import numpy as np
def search_numexpr(arr, l, u):
return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
%matplotlib widget
import matplotlib.pyplot as plt
plt.style.use('ggplot')
b.plot()