我试图使用 guvectorize 向量化一个返回元组的简单函数。显然,numba 文档中没有任何 guvectorize 函数返回元组(tuple)的可用示例。
最初,我试图这样做:
z = (x+y, x-y)
然后,根据 Stack Overflow 上的一个答案,我将其改为如下形式:
z[:] = (x+y, x-y)
尽管如此,我仍然遇到了错误,而且这些错误信息对我来说很难理解。我想要的是向量化这样一个函数:它接受多个维度相同的数组,并返回一个与输入数组维度相同、每个元素都是元组的数组。例如,假设示例函数的输入数组为:
a = array([[4, 7, 9],
[7, 1, 2]])
b = array([[5, 6, 6],
[2, 5, 6]])
然后输出应为:
c = array([[ (9, -1), (13, 1), (15, 3)],
[ (9, 5), (6, -4), (8, -4)]], dtype=object)
我的示例代码和错误如下:
from numba import void, float64, UniTuple, guvectorize
# Failing attempt from the question, kept unchanged so it matches the error output below.
# The traceback shows two separate problems:
#  * The third argument is declared as UniTuple(float64[:], 2). Numba types it as a
#    tuple, and the gufunc builder treats that as a scalar type, so it cannot be bound
#    to the '(n)' output layout ("scalar type ... given for non scalar argument #3").
#  * A tuple does not support slice assignment, so `z[:] = ...` fails type inference
#    (the setitem error), and compilation falls back to object mode.
# Declaring two separate float64[:] output arrays instead of one tuple avoids both.
@guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
    z[:] = (x+y, x-y)
<ipython-input-24-6920fb0e2a76>:2: NumbaWarning:
Compilation is falling back to object mode WITHOUT looplifting enabled because Function "fun" failed type inference due to: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (tuple(array(float64, 1d, A) x 2), slice<a:b>, tuple(array(float64, 1d, C) x 2))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
All templates rejected with literals.
In definition 9:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at <ipython-input-24-6920fb0e2a76> (4)
File "<ipython-input-24-6920fb0e2a76>", line 4:
def fun(x, y, z):
z[:] = (x+y, x-y)
^
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function "fun" was compiled in object mode without forceobj=True.
File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^
self.func_ir.loc))
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.
For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit
File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^
warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-24-6920fb0e2a76> in <module>
1 from numba.types import UniTuple
----> 2 @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
3 def fun(x, y, z):
4 z[:] = (x+y, x-y)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/decorators.py in wrap(func)
178 for fty in ftylist:
179 guvec.add(fty)
--> 180 return guvec.build_ufunc()
181
182 return wrap
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build_ufunc(self)
304 for sig in self._sigs:
305 cres = self._cres[sig]
--> 306 dtypenums, ptr, env = self.build(cres)
307 dtypelist.append(dtypenums)
308 ptrlist.append(utils.longint(ptr))
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build(self, cres)
328 info = build_gufunc_wrapper(
329 self.py_func, cres, self.sin, self.sout,
--> 330 cache=self.cache, is_parfors=False,
331 )
332
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors)
501 else _GufuncWrapper)
502 return wrapcls(
--> 503 py_func, cres, sin, sout, cache, is_parfors=is_parfors,
504 ).build()
505
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build(self)
454 def build(self):
455 wrapper_name = "__gufunc__." + self.fndesc.mangled_name
--> 456 wrapperlib = self._compile_wrapper(wrapper_name)
457 return _wrapper_info(
458 library=wrapperlib, env=self.env, name=wrapper_name,
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _compile_wrapper(self, wrapper_name)
445 wrapperlib.enable_object_caching()
446 # Build wrapper
--> 447 self._build_wrapper(wrapperlib, wrapper_name)
448 # Cache
449 self.cache.save_overload(self.cres.signature, wrapperlib)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _build_wrapper(self, library, name)
399 self.sin + self.sout)):
400 ary = GUArrayArg(self.context, builder, arg_args,
--> 401 arg_steps, i, step_offset, typ, sym, sym_dim)
402 step_offset += len(sym)
403 arrays.append(ary)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in __init__(self, context, builder, args, steps, i, step_offset, typ, syms, sym_dim)
656 if syms:
657 raise TypeError("scalar type {0} given for non scalar "
--> 658 "argument #{1}".format(typ, i + 1))
659 self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)
660
TypeError: scalar type tuple(array(float64, 1d, A) x 2) given for non scalar argument #3
答案 0 :(得分:1)
下面是一个 Numba 示例,它返回由两个二维 NumPy 数组组成的元组。就这个例子而言,我认为直接使用 NumPy 的加法和减法就足够了(只有两个数组时完全没问题),不过这里还是给出一个可运行的 Numba 示例。我用下面这种方式应用装饰器(直接调用 jit(...) 来包装函数),只是因为我觉得这样方便;如果你更喜欢常规的 @ 装饰器写法,两者完全等价。
import numpy as np
try:
from numba import jit, prange
except ImportError:
numba_opt = False
else:
numba_opt = True
# Sample inputs from the question, stored as float64 so they match the
# float64[:,:] argument types declared in the jit signature used below.
a = np.array([[4, 7, 9], [7, 1, 2]], dtype=np.float64)
b = np.array([[5, 6, 6], [2, 5, 6]], dtype=np.float64)
def numba_function(a: np.ndarray, b: np.ndarray):
    """Return the element-wise sum and difference of two 2-D arrays.

    The explicit double loop is deliberate: it keeps the function compilable
    by numba in nopython mode.

    Parameters
    ----------
    a, b : np.ndarray
        Arrays indexed with the same ``(row, col)`` pairs. ``b`` must be at
        least as large as ``a`` along the first two axes.

    Returns
    -------
    tuple of np.ndarray
        ``(a + b, a - b)``. Each result has the shape and dtype of ``a``.
    """
    dims = np.shape(a)
    n_rows = dims[0]
    n_cols = dims[1]
    total = np.zeros_like(a)
    diff = np.zeros_like(a)
    for row in range(n_rows):
        for col in range(n_cols):
            lhs = a[row, col]
            rhs = b[row, col]
            total[row, col] = lhs + rhs
            diff[row, col] = lhs - rhs
    return total, diff
if numba_opt:
    # Compile eagerly for the explicit signature: two 2-D float64 arrays in,
    # a 2-tuple of 2-D float64 arrays out.
    fun_rec = jit(signature_or_function='UniTuple(float64[:,:],2)(float64[:,:],float64[:,:])',
                  nopython=True, parallel=False, cache=True, fastmath=True, nogil=True)(numba_function)
else:
    # Numba failed to import (numba_opt is False). Fall back to the plain-Python
    # implementation so that fun_rec is always defined and the script still runs,
    # only more slowly.
    fun_rec = numba_function

p, m = fun_rec(a, b)
print(p)
print(m)
答案 1 :(得分:0)
这样似乎可以按预期工作——不返回元组,而是把两个结果声明为两个独立的输出数组:
@guvectorize([void(float64[:], float64[:], float64[:], float64[:])],
             '(n), (n) -> (n), (n)')
def fun(x, y, addition, subtraction):
    """Write the element-wise sum and difference of ``x`` and ``y``.

    A gufunc cannot return a tuple, so the two results are declared as two
    separate output arrays in the layout ``'(n), (n) -> (n), (n)'``.
    Calling ``fun(a, b)`` then hands both outputs back as a tuple
    ``(a + b, a - b)``.
    """
    for k in range(x.shape[0]):
        addition[k] = x[k] + y[k]
        subtraction[k] = x[k] - y[k]
例如:
>>> a = np.array([1., 2., 3.])
>>> b = np.array([-1., 4., 2.])
>>> fun(a, b)
(array([0., 6., 5.]), array([ 2., -2., 1.]))