具有if语句的Numpy数组的功能

时间:2011-11-07 13:04:12

标签: python numpy matplotlib

我正在使用MatplotlibNumpy制作一些情节。我希望定义一个函数,给定一个数组返回另一个数组,其值为 elementwise ,例如:

def func(x):
     return x*10

x = numpy.arrange(-1,1,0.01)
y = func(x)

这很好。现在我希望在func中有一个if语句,例如:

def func(x):
     if x<0:
          return 0
     else:
          return x*10

x = numpy.arrange(-1,1,0.01)
y = func(x)

不幸的是,这会引发以下错误

Traceback (most recent call last):
  File "D:\Scripts\test.py", line 17, in <module>
    y = func(x)
  File "D:\Scripts\test.py", line 11, in func
    if x<0:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我查看了all()any()的文档,但它们并不符合我的需求。那么有没有一种很好的方法可以像第一个例子那样使函数处理数组元素?

6 个答案:

答案 0 :(得分:13)

在将func应用于数组x之前使用numpy.vectorize来包装func:

from numpy import vectorize
vfunc = vectorize(func)
y = vfunc(x)

答案 1 :(得分:11)

我知道这个答案为时已晚,但我很高兴学习NumPy。您可以使用numpy.where。

自行向量化该函数
def func(x):
    import numpy as np
    x = np.where(x<0, 0., x*10)
    return x   

<强>实施例

使用标量作为数据输入:

x = 10
y = func(10)
y = array(100.0)

使用数组作为数据输入:

x = np.arange(-1,1,0.1)
y = func(x)
y = array([ -1.00000000e+00,  -9.00000000e-01,  -8.00000000e-01,
    -7.00000000e-01,  -6.00000000e-01,  -5.00000000e-01,
    -4.00000000e-01,  -3.00000000e-01,  -2.00000000e-01,
    -1.00000000e-01,  -2.22044605e-16,   1.00000000e-01,
     2.00000000e-01,   3.00000000e-01,   4.00000000e-01,
     5.00000000e-01,   6.00000000e-01,   7.00000000e-01,
     8.00000000e-01,   9.00000000e-01])

<强>注意事项

1)如果x是一个屏蔽数组,则需要使用np.ma.where,因为这适用于屏蔽数组。

答案 2 :(得分:10)

这应该做你想要的:

def func(x):
    small_indices = x < 10
    x[small_indices] = 0
    x[invert(small_indices)] *= 10
    return x

invert是一个Numpy函数。请注意,这会修改参数。为防止这种情况发生,您必须修改并返回copy x

答案 3 :(得分:5)

(我意识到这是一个老问题,但是......)

此处未提及另一个选项 - 使用np.choose

np.choose(
    # the boolean condition
    x < 0,
    [
        # index 0: value if condition is False
        10 * x,
        # index 1: value if condition is True
        0
    ]
)

虽然不是非常易读,但这只是一个表达式(不是一系列语句),并且不会影响numpy的固有速度(如np.vectorize那样)。

答案 4 :(得分:1)

x = numpy.arrange(-1,1,0.01)
mask = x>=0
y = numpy.zeros(len(x))
y[mask] = x[mask]*10

mask是一个布尔数组,等于True是与条件匹配的数组索引和其他地方的False。最后一行替换原始数组中的所有值,该值多次为10。

编辑以反映Bjorn的相关评论

答案 5 :(得分:0)

不确定为什么需要功能

x = np.arange(-1, 1, 0.01)
y = x * np.where(x < 0, 0, 10)