我想做这样的事情:
import numpy as np
import matplotlib.pyplot as plt
xpts = np.linspace(0, 100, 1000)
test = lambda x: 0.5 if x > 66 else 1.0
plt.plot(xpts, test(xpts))
但是我收到了错误:
ValueError:具有多个元素的数组的真值 暧昧。使用a.any()或a.all()
另一方面,我能够做到:
print(test(50), test(70))
1.0 0.5
为什么会发生这种情况并且有解决方案?
答案 0 :(得分:2)
如果数组包含多个元素,则无法将数组转换为bool
:
In [21]: bool(np.array([1]))
Out[21]: True
In [22]: bool(np.array([1, 2]))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-22-5ba97928842c> in <module>()
----> 1 bool(np.array([1, 2]))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
您可能希望对数组中的每个元素应用test
函数:
In [23]: plt.plot(xpts, [test(x) for x in xpts])
Out[23]: [<matplotlib.lines.Line2D at 0x7fa560efeeb8>]
您也可以创建函数的矢量化版本并将其应用于数组,而不需要列表理解:
In [24]: test_v = np.vectorize(test)
In [25]: plt.plot(xpts, test_v(xpts))
Out[25]: [<matplotlib.lines.Line2D at 0x7fa560f19080>]
答案 1 :(得分:1)
Python列表不允许您对列表进行比较。所以你不能,例如,范围(10)&gt; 10.相反,您可以将输入转换为numpy数组并执行范围检查。 Ť
import numpy as np
import matplotlib.pyplot as plt
xpts = np.linspace(0, 100, 1000)
test = lambda x: (np.array(x) <= 66)*.5 + .5
print xpts, test(xpts)
plt.plot(xpts, test(xpts))
plt.show()