我试图使用matplotlib绘制一个简单的离散分布:
如何从x = np.linspace(-1, 2)
开始?
到目前为止我尝试的是:
def mapDiscProb(x):
if np.any(x < 0):
return 0.3 + x * 0
elif np.any(x >= 1):
return 0.2 + x * 0
else:
return 0.5 + x * 0
x = np.linspace(-1, 2)
y = mapDiscProb(x4)
ax.plot(x, y, clip_on = False)
结果只是从-1到2的整数线,好像elif
和else
没有被执行。
我的预期输出是三条断开的水平线,这是离散pmf的标准。
答案 0 :(得分:3)
您可以使用
numpy.piecewise
numpy.piecewise
允许根据某些条件定义函数。这里有三个条件[x<0, x>=1, (x>=0) & (x<1)]
,您可以定义一个用于每个条件的函数。
import matplotlib.pyplot as plt
import numpy as np
l1 = lambda x: 0.3 + x * 0
l2 = lambda x: 0.2 + x * 0
l3 = lambda x: 0.5 + x * 0
mapDiscProb=lambda x: np.piecewise(x, [x<0, x>=1, (x>=0) & (x<1)],[l1,l2,l3])
x = np.linspace(-1, 2)
y = mapDiscProb(x)
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
numpy.vectorize
numpy.vectorize
向量化一个函数,该函数用于使用标量调用,这样就可以对数组中的每个元素进行求值。这允许if
/ else
语句按预期使用。
import matplotlib.pyplot as plt
import numpy as np
def mapDiscProb(x):
if x < 0:
return 0.3
elif x >= 1:
return 0.2
else:
return 0.5
x = np.linspace(-1, 2)
y = np.vectorize(mapDiscProb)(x)
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
numpy.select
(感谢PaulH这个想法)numpy.select
可以根据条件从不同的数组中选择值。对于分段常数函数,这是一个简单的工具,因为它不需要构建任何附加函数(单线程)。
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-1, 2)
y = np.select([x<0, x<1, x>1], [0.3, 0.5, 0.2])
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
所有情况下的输出:
如果您不想显示任何垂直线条,则可以根据条件绘制尽可能多的绘图。
import matplotlib.pyplot as plt
import numpy as np
l1 = lambda x: 0.3 + x * 0
l2 = lambda x: 0.2 + x * 0
l3 = lambda x: 0.5 + x * 0
x = np.linspace(-1, 2)
func = [l1,l2,l3]
cond = [x<0, x>=1, (x>=0) & (x<1)]
fig, ax = plt.subplots()
for f,c in zip(func,cond):
xi = x[c]
ax.plot(xi, f(xi), color="C0")
plt.show()
或者,使用numpy.select
,您可以修改x
数组,以确保包含值[0,1]
,这些值位于条件之间的边缘。选择明确排除这些值的条件,[x<0, (x>0) & (x<1), x>1]
(注意缺少任何等号)将允许将这些值设置为nan。未显示Nan值,因此出现间隙。
import matplotlib.pyplot as plt
import numpy as np
x = np.sort(np.append(np.linspace(-1, 2),[0,1]))
y = np.select([x<0, (x>0) & (x<1), x>1], [0.3, 0.5, 0.2], np.nan)
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()
答案 1 :(得分:0)
如果您只想为该示例执行此操作,则应该可以使用
import numpy as np
from matplotlib import pyplot as plt
x=np.linspace(-1,2)
plt.plot(x[x < 0], 0.3 + x[x < 0]*0, color='blue')
plt.plot(x[(x >= 0) & (x <= 1)], 0.5 + x[(x>=0) & (x<=1)]*0, color='blue')
plt.plot(x[x > 1],0.2 + x[x > 1]*0, color='blue')
plt.show()
当谈到离散函数时,掩蔽是我的方法。