Vectorizing a function that returns a tuple with numba `guvectorize`

Posted: 2019-08-26 17:52:24

Tags: python vectorization numba

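I am trying to use guvectorize to vectorize a simple function that returns a tuple. Apparently, the numba documentation does not contain any working guvectorize example in which the function returns a tuple.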

Initially, I tried this:

z = (x+y, x-y)

Then, following a Stack Overflow answer, I changed it to the following:

z[:] = (x+y, x-y)

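Even so, I still get errors that I find hard to decipher. What I want is to vectorize a function that takes several arrays of the same dimensions and returns an array of tuples with the same dimensions as the input arrays. For example, suppose the input arrays to the sample function are: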

a = array([[4, 7, 9],
           [7, 1, 2]])
b = array([[5, 6, 6],
           [2, 5, 6]])

Then the output should be:

c = array([[ (9, -1), (13, 1), (15, 3)],
           [ (9, 5),  (6, -4),  (8, -4)]], dtype=object)
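An array of Python tuples has to be an object array, and numba cannot compile object arrays in nopython mode. A common alternative is to keep the two results along an extra trailing axis of length 2. A minimal NumPy-only sketch, assuming that layout is acceptable (the names `pairs` and `c` are just for illustration):

```python
import numpy as np

a = np.array([[4, 7, 9],
              [7, 1, 2]])
b = np.array([[5, 6, 6],
              [2, 5, 6]])

# Store the two results on a trailing axis instead of in tuples:
# pairs[i, j] == [a[i, j] + b[i, j], a[i, j] - b[i, j]]
pairs = np.stack((a + b, a - b), axis=-1)   # shape (2, 3, 2)

# Only if an object array of Python tuples is really required:
c = np.empty(a.shape, dtype=object)
for idx in np.ndindex(a.shape):
    # A full integer index on an object array stores the tuple as-is.
    c[idx] = tuple(pairs[idx].tolist())
```

The trailing-axis form stays a regular numeric array, so it can be sliced (`pairs[..., 0]` for the sums) and passed back into numba code.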

My sample code and the errors are below:

from numba import void, float64, UniTuple, guvectorize
@guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)') 
def fun(x, y, z): 
    z[:] = (x+y, x-y)

Running this produces the following warnings and error:

<ipython-input-24-6920fb0e2a76>:2: NumbaWarning: 
Compilation is falling back to object mode WITHOUT looplifting enabled because Function "fun" failed type inference due to: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (tuple(array(float64, 1d, A) x 2), slice<a:b>, tuple(array(float64, 1d, C) x 2))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at <ipython-input-24-6920fb0e2a76> (4)

File "<ipython-input-24-6920fb0e2a76>", line 4:
def fun(x, y, z):
    z[:] = (x+y, x-y)
    ^

  @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function "fun" was compiled in object mode without forceobj=True.

File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^

  self.func_ir.loc))
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning: 
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^

  warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-6920fb0e2a76> in <module>
      1 from numba.types import UniTuple
----> 2 @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
      3 def fun(x, y, z):
      4     z[:] = (x+y, x-y)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/decorators.py in wrap(func)
    178         for fty in ftylist:
    179             guvec.add(fty)
--> 180         return guvec.build_ufunc()
    181 
    182     return wrap

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build_ufunc(self)
    304         for sig in self._sigs:
    305             cres = self._cres[sig]
--> 306             dtypenums, ptr, env = self.build(cres)
    307             dtypelist.append(dtypenums)
    308             ptrlist.append(utils.longint(ptr))

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build(self, cres)
    328         info = build_gufunc_wrapper(
    329             self.py_func, cres, self.sin, self.sout,
--> 330             cache=self.cache, is_parfors=False,
    331         )
    332 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors)
    501                else _GufuncWrapper)
    502     return wrapcls(
--> 503         py_func, cres, sin, sout, cache, is_parfors=is_parfors,
    504     ).build()
    505 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build(self)
    454     def build(self):
    455         wrapper_name = "__gufunc__." + self.fndesc.mangled_name
--> 456         wrapperlib = self._compile_wrapper(wrapper_name)
    457         return _wrapper_info(
    458             library=wrapperlib, env=self.env, name=wrapper_name,

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _compile_wrapper(self, wrapper_name)
    445                 wrapperlib.enable_object_caching()
    446                 # Build wrapper
--> 447                 self._build_wrapper(wrapperlib, wrapper_name)
    448                 # Cache
    449                 self.cache.save_overload(self.cres.signature, wrapperlib)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _build_wrapper(self, library, name)
    399                                            self.sin + self.sout)):
    400             ary = GUArrayArg(self.context, builder, arg_args,
--> 401                              arg_steps, i, step_offset, typ, sym, sym_dim)
    402             step_offset += len(sym)
    403             arrays.append(ary)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in __init__(self, context, builder, args, steps, i, step_offset, typ, syms, sym_dim)
    656             if syms:
    657                 raise TypeError("scalar type {0} given for non scalar "
--> 658                                 "argument #{1}".format(typ, i + 1))
    659             self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)
    660 

TypeError: scalar type tuple(array(float64, 1d, A) x 2) given for non scalar argument #3

2 Answers:

Answer 0 (score: 1)

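Here is a Numba example that returns a tuple of two 2-D NumPy arrays. In this case I think you could simply use addition and subtraction in NumPy (that works fine if you have two arrays), but I have added a working Numba example here. I apply the decorator in the way shown below because I find it convenient, but it is completely equivalent if you prefer to change it back to the usual decorator syntax.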

import numpy as np

try:
    from numba import jit
except ImportError:
    numba_opt = False
else:
    numba_opt = True

a = np.array([[4, 7, 9],
              [7, 1, 2]], dtype=float)
b = np.array([[5, 6, 6],
              [2, 5, 6]], dtype=float)

def numba_function(a: np.ndarray, b: np.ndarray):
    # Element-wise sum and difference of two 2-D arrays of the same shape.
    l0 = np.shape(a)[0]
    l1 = np.shape(a)[1]
    p = np.zeros_like(a)
    m = np.zeros_like(a)
    for i in range(l0):
        for j in range(l1):
            p[i, j] = a[i, j] + b[i, j]
            m[i, j] = a[i, j] - b[i, j]
    return p, m

if numba_opt:
    # Compile eagerly for one signature: two 2-D float64 arrays in,
    # a homogeneous tuple (UniTuple) of two 2-D float64 arrays out.
    fun_rec = jit(signature_or_function='UniTuple(float64[:,:],2)(float64[:,:],float64[:,:])',
                  nopython=True, parallel=False, cache=True, fastmath=True, nogil=True)(numba_function)
else:
    # Without numba installed, fall back to the plain Python function.
    fun_rec = numba_function

p, m = fun_rec(a, b)
print(p)
print(m)

Answer 1 (score: 0)

Instead of returning a tuple, declare one output array per result. This seems to work as expected:

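import numpy as np
from numba import guvectorize, void, float64

# Two separate output arrays replace the tuple-typed output.
@guvectorize([void(float64[:], float64[:], float64[:], float64[:])], '(n), (n) -> (n), (n)')
def fun(x, y, addition, subtraction):
    addition[:] = x + y
    subtraction[:] = x - y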

For example:

>>> a = np.array([1., 2., 3.])
>>> b = np.array([-1., 4., 2.])
>>> fun(a, b)
(array([0., 6., 5.]), array([ 2., -2.,  1.]))