比较numba编译函数中的字符串

时间:2017-10-12 11:42:11

标签: python string numba

我正在寻找比较使用numba jit编译的python函数中的字符串的最佳方法(没有python模式,python 3)。

用例如下:

import numba as nb

@nb.jit(nopython = True, cache = True)
def foo(a, t = 'default'):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        ...

但是,会返回以下错误:

Invalid usage of == with parameters (str, const('awesome'))

我尝试使用字节,但无法成功。

谢谢!

莫里斯指出问题Python: can numba work with arrays of strings in nopython mode?,但我正在查看原生python,而不是numba支持的numpy子集。

2 个答案:

答案 0 :(得分:3)

Numba不支持nopython模式下的字符串。

来自documentation

  

2.6.2。内置类型

     

2.6.2.1。 int,bool [...]

     

2.6.2.2。漂浮,复杂[...]

     

2.6.2.3。元组[...]

     

2.6.2.4。列表[...]

     

2.6.2.5。设置[...]

     

2.6.2.7。 bytes,bytearray,memoryview

     

bytearray类型,在Python 3上,bytes类型支持索引,迭代和检索len()

     

[...]

因此根本不支持字符串,字节不支持相等性检查。

但是你可以传入bytes并迭代它们。这使得编写自己的比较函数成为可能:

import numba as nb

@nb.njit
def bytes_equal(a, b):
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True

不幸的是,下一个问题是numba无法“降低”字节,因此您无法直接对函数中的字节进行硬编码。但是字节基本上只是整数,bytes_equal函数适用于numba支持的所有类型,它们具有长度并且可以迭代。所以你可以简单地将它们存储为列表:

import numba as nb

@nb.njit
def foo(a, t):
    if bytes_equal(t, [97, 119, 101, 115, 111, 109, 101]):
        return a**2
    elif bytes_equal(t, [100, 101, 102, 97, 117, 108, 116]):
        return a**3
    else:
        return a

或作为全局数组(感谢@chrisb - 请参阅注释):

import numba as nb
import numpy as np

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')

@nb.njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a

两者都能正常工作:

>>> foo(10, b'default')
1000
>>> foo(10, b'awesome')
100
>>> foo(10, b'awe')
10

但是,您不能将bytes数组指定为默认值,因此需要显式提供t变量。这样做也感觉很麻烦。

我的观点:只需在正常函数中执行if t == ...检查,并在if内调用专门的numba函数。字符串比较在Python中非常快,只需将数学/数组密集型内容包装在numba函数中:

import numba as nb

@nb.njit
def awesome_func(a):
    return a**2

@nb.njit
def default_func(a):
    return a**3

@nb.njit
def other_func(a):
    return a

def foo(a, t='default'):
    if t == 'awesome':
        return awesome_func(a)
    elif t == 'default':
        return default_func(a)
    else:
        return other_func(a)

但请确保您确实需要numba功能。有时普通的Python / NumPy足够快。只需简要介绍numba解决方案和Python / NumPy解决方案,看看numba是否能让它显着提高速度。 :)

答案 1 :(得分:0)

我建议接受@ MSeifert的回答,但作为这类问题的另一种选择,请考虑使用enum

在python中,字符串通常用作一种枚举,而numba内置了对枚举的支持,因此可以直接使用它们。

import enum

class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2

import numba

@numba.njit
def foo(a, t=FooOptions.DEFAULT):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**2
    else:
        return a

foo(10, FooOptions.AWESOME)
Out[5]: 100