如何创建一个接受numpy数组,可迭代数组或标量的numpy函数?

时间:2012-09-29 13:31:56

标签: python arrays numpy

假设我有这个:

def incrementElements(x):
   return x+1

但是我想修改它以便它可以采用numpy数组,可迭代数组或标量数,并将参数提升为numpy数组并为每个元素添加1。

我怎么能这样做?我想我可以测试参数类,但这似乎是一个坏主意。如果我这样做:

def incrementElements(x):
   return numpy.array(x)+1

它适用于数组或迭代,但不适用于标量。这里的问题是标量x的numpy.array(x)产生一些奇怪的对象,它由一个numpy数组包含但不是一个“真正的”数组;如果我添加一个标量,结果会降级为标量。

4 个答案:

答案 0 :(得分:9)

你可以尝试

def incrementElements(x):
    x = np.asarray(x)
    return x+1

np.asarray(x)相当于np.array(x, copy=False),意味着标量或可迭代将转换为ndarray,但如果x已经是ndarray },其数据不会被复制。

如果您传递标量并希望ndarray作为输出(不是标量),则可以使用:

def incrementElements(x):
    x = np.array(x, copy=False, ndmin=1)
    return x

ndmin=1参数将强制数组至少有一个维度。使用ndmin=2至少2个维度,依此类推。您还可以使用等效的np.atleast_1d(或np.atleast_2d作为2D版本...)

答案 1 :(得分:2)

只要您的函数专门使用ufunc(或类似的东西)隐式循环输入值,Pierre GM的答案就很棒。如果你的函数需要迭代输入,那么np.asarray做得不够,因为你不能迭代NumPy标量:

import numpy as np

x = np.asarray(1)
for xval in x:
    print(np.exp(xval))

Traceback (most recent call last):
  File "Untitled 2.py", line 4, in <module>
    for xval in x:
TypeError: iteration over a 0-d array

如果您的函数需要迭代输入,则使用np.atleast_1dnp.squeeze(参见Array manipulation routines — NumPy Manual)可以使用以下内容。我包含了一个aaout(“Always Array OUT”)arg,因此您可以指定是否需要标量输入来生成单元素数组输出;如果不需要它可以被删除:

def scalar_or_iter_in(x, aaout=False):
    """
    Gather function evaluations over scalar or iterable `x` values.

    aaout :: boolean
        "Always array output" flag:  If True, scalar input produces
        a 1-D, single-element array output.  If false, scalar input
        produces scalar output.
    """
    x = np.asarray(x)
    scalar_in = x.ndim==0

    # Could use np.array instead of np.atleast_1d, as follows:
    # xvals = np.array(x, copy=False, ndmin=1)
    xvals = np.atleast_1d(x)
    y = np.empty_like(xvals, dtype=float)  # dtype in case input is ints
    for i, xx in enumerate(xvals):
        y[i] = np.exp(xx)  # YOUR OPERATIONS HERE!

    if scalar_in and not aaout:
        return np.squeeze(y)
    else:
        return y


print(scalar_or_iter_in(1.))
print(scalar_or_iter_in(1., aaout=True))
print(scalar_or_iter_in([1,2,3]))


2.718281828459045
[2.71828183]
[ 2.71828183  7.3890561  20.08553692]

当然,对于取幂,你不应该像这里那样显式迭代,但是使用NumPy ufuncs可能无法表达更复杂的操作。如果需要迭代,但想要对标量输入是否产生单元素数组输出进行类似控制,则函数的中间部分可能更简单,但返回必须处理{{1} }:

np.atleast_1d

我怀疑在大多数情况下,def scalar_or_iter_in(x, aaout=False): """ Gather function evaluations over scalar or iterable `x` values. aaout :: boolean "Always array output" flag: If True, scalar input produces a 1-D, single-element array output. If false, scalar input produces scalar output. """ x = np.asarray(x) scalar_in = x.ndim==0 y = np.exp(x) # YOUR OPERATIONS HERE! if scalar_in and not aaout: return np.squeeze(y) else: return np.atleast_1d(y) 标志不是必需的,并且您总是希望标量输出带有标量输入。在这种情况下,返回应该是:

aaout

答案 2 :(得分:0)

这是一个老问题,但这是我的两分钱。

尽管皮尔·通用的答案很有效,但它有将---将标量转换为阵列的副作用。如果这是你想要/需要的,那就停止阅读;否则,继续。虽然这可能没问题(并且可能对lists和其他iterables返回np.array有好处),但可以认为对于标量它应该返回标量。如果这是所期望的行为,为什么不遵循python的EAFP哲学。这就是我通常做的事情(我更改了示例以显示np.asarray返回&#34;标量&#34;)时会发生什么:

def saturateElements(x):
    x = np.asarray(x)
    try:
        x[x>1] = 1
    except TypeError:
        x = min(x,1)
    return x

我意识到它比Pierre GM的答案更冗长,但正如我所说,如果传递标量,或者np.array是数组或者可迭代的,则此解决方案将返回标量。

答案 3 :(得分:0)

现有解决方案都可以使用。这是另一个选择。

def incrementElements(x):
    try:
        iter(x)
    except TypeError:
        x = [x]
    x = np.array(x)
    # code that operators on x