我试图在Python中用pyplot绘制一个函数,问题可以归结为:
import numpy as np
import matplotlib.pyplot as plt
def func(x):
if x<0:
return x*x
interval = np.arange(-4.0,4.0,0.1)
plt.plot(interval, func(interval))
plt.show()
会引发以下错误:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我该如何避免这种情况?
答案 0 :(得分:3)
如果您想更改低于零的值,可以使用np.where
:
import numpy as np
def func(x):
return np.where(x < 0, x*x, x)
print(func(np.arange(-4, 5))) # array([16, 9, 4, 1, 0, 1, 2, 3, 4])
如果您只想将值低于零,则可以使用indexing with a boolean array:
def func(x):
return x[x<0] ** 2
print(func(np.arange(-4, 5))) # array([16, 9, 4, 1])
更通用:numpy.array
上的比较运算符只返回一个布尔数组:
>>> arr > 2
array([False, False, True], dtype=bool)
>>> arr == 2
array([False, True, False], dtype=bool)
异常
执行ValueError:具有多个元素的数组的真值是不明确的。使用a.any()或a.all()
bool(somearray)
时会发生。在很多情况下,bool()
调用是隐含的(所以发现它可能不是很明显)。此隐式bool()
来电的示例包括if
,while
,and
和or
:
>>> import numpy as np
>>> arr = np.array([1, 2, 3])
>>> bool(arr)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> if arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> while arr: pass
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> arr and arr
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> arr or arr
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
在您的情况下,if x < 0
是异常的原因,因为x < 0
返回一个布尔数组,然后if
尝试获取该数组的bool()
。如上例所示,抛出了你得到的异常。
答案 1 :(得分:1)
首先,您需要确定大于或等于零的值应该发生什么。
所以让我们假设你有这个功能
def func(x):
if x<0:
return x*x
else:
return 2*x
现在,像func(np.arange(-4,4,0.1))
之类的东西不起作用,因为x的一半是正值,一半是负值。如果你问x是否为正数,那么答案就是&#34;它取决于......&#34;。这就是错误告诉你的。
因此,您需要确保该函数以元素方式处理输入数组。为此,您可以使用numpy.vectorize
。
func2 = np.vectorize(func)
interval = np.arange(-4.0,4.0,0.1)
plt.plot(interval, func2(interval))
然后绘制所需的结果。
您还可以决定直接编写一个函数,该函数接受一个数组作为输入。上面的例子可能看起来像
def func3(x):
return x*x*(x<0) + 2*x*(x>=0)