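The following code generates 2 figures, one on top of the other, in a Jupyter notebook. How can I make the function plotAllDist(...) receive the plots from plot1Dist(...) and draw them as subplots, side by side instead?
I tried reading some posts, but to no avail...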
import numpy as np
import matplotlib.pyplot as plt

def plot1Dist(x, sigmaJ, pmf, title):
    fig = plt.figure()
    # column 0: distinct sigma values, column 1: how often each occurs
    freqTable = np.array(np.unique(sigmaJ, return_counts=True)).T
    simu = plt.plot(freqTable[:,0], freqTable[:,1], label='Simulation')
    # scale the pmf by the sample count so it is comparable to the frequencies
    dist = pmf * sum(freqTable[:,1])
    model = plt.plot(x, dist, label='Model')
    # add description to the plot
    plt.legend(loc="upper right")
    plt.xlabel('sigma')
    plt.ylabel('Frequency')
    plt.xticks(np.arange(min(x), max(x)+1, 5))
    plt.title(title)
    plt.show()

def plotAllDist(x, sigmaJ, e, pmf0, pmf1, FLIPS):
    ONES = [i for i in range(e.size) if e[i] == 1]
    ZEROS = [j for j in range(e.size) if e[j] == 0]
    if (FLIPS == 0):
        title = 'Distribution of sigma before bit-flipping'
    elif (FLIPS > 0):
        title = 'Distribution of sigma after ' + str(FLIPS) + ' flips'
    plot1Dist(x, sigmaJ[ZEROS], pmf0, title + ' (e=0)')
    plot1Dist(x, sigmaJ[ONES], pmf1, title + ' (e=1)')
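Answer 0 (score: 0)
You should use plt.subplots and the object-oriented API for matplotlib. First, avoid creating a new figure inside plot1Dist, because that forces a new figure on every call (which is why, in the notebook, they are drawn one on top of the other).
I would modify your code like this: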
def plot1Dist(ax, x, sigmaJ, pmf, title):
    """Plots distribution onto Axes object `ax`."""
    # Just replace all instances of `plt` with `ax`. Their APIs aren't exactly
    # the same but they are very similar
    freqTable = np.array(np.unique(sigmaJ, return_counts=True)).T
    simu = ax.plot(freqTable[:,0], freqTable[:,1], label='Simulation')
    dist = pmf * sum(freqTable[:,1])
    model = ax.plot(x, dist, label='Model')
    # add description to the plot
    ax.legend(loc="upper right")
    ax.set_xlabel('sigma')
    ax.set_ylabel('Frequency')
    # Axes has no `xticks` method; the counterpart of plt.xticks is set_xticks
    ax.set_xticks(np.arange(min(x), max(x)+1, 5))
    ax.set_title(title)

def plotAllDist(x, sigmaJ, e, pmf0, pmf1, FLIPS):
    ONES = [i for i in range(e.size) if e[i] == 1]
    ZEROS = [j for j in range(e.size) if e[j] == 0]
    if (FLIPS == 0):
        title = 'Distribution of sigma before bit-flipping'
    elif (FLIPS > 0):
        title = 'Distribution of sigma after ' + str(FLIPS) + ' flips'
    # number of columns = 2, so that the plots are side by side
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 7))
    plot1Dist(axes[0], x, sigmaJ[ZEROS], pmf0, title + ' (e=0)')
    plot1Dist(axes[1], x, sigmaJ[ONES], pmf1, title + ' (e=1)')
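To see the layout in isolation, here is a minimal self-contained sketch of the same pattern. The helper name `plot_counts` and the synthetic `sigmaJ` and `e` arrays are made up for illustration and are not taken from the question; only the idea of passing an Axes object into the plotting helper comes from the code above.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_counts(ax, values, title):
    """Plot the frequency of each distinct value onto the given Axes."""
    uniq, counts = np.unique(values, return_counts=True)
    ax.plot(uniq, counts, label="Simulation")
    ax.set_xlabel("sigma")
    ax.set_ylabel("Frequency")
    ax.set_title(title)
    ax.legend(loc="upper right")


# Synthetic stand-ins for the question's data (assumption, for illustration only).
rng = np.random.default_rng(0)
sigmaJ = rng.binomial(n=40, p=0.3, size=1000)
e = rng.integers(0, 2, size=1000)

# One figure holding two Axes in a single row, so the panels sit side by side.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
plot_counts(axes[0], sigmaJ[e == 0], "e=0")
plot_counts(axes[1], sigmaJ[e == 1], "e=1")
fig.tight_layout()  # keep the labels of the two panels from overlapping
```

Because the helper never calls plt.figure() or plt.show(), the caller decides how many panels there are and where each one goes. Boolean masks such as `e == 0` are used here in place of the index lists ONES and ZEROS; both select the same elements.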
I really encourage you to check out the object-oriented API! In my opinion it is much less confusing and more flexible than the "pyplot" API.
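As a rough guide to the naming differences between the two interfaces, the sketch below shows the Axes counterparts of the pyplot calls used in the question. The data is a made-up placeholder; the point is only the method names, most of which gain a `set_` prefix.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.arange(0, 21)
ax.plot(x, x ** 2, label="Model")  # plt.plot   -> ax.plot (same name)

# pyplot function  ->  Axes method
ax.set_xlabel("sigma")                           # plt.xlabel -> ax.set_xlabel
ax.set_ylabel("Frequency")                       # plt.ylabel -> ax.set_ylabel
ax.set_title("Demo")                             # plt.title  -> ax.set_title
ax.set_xticks(np.arange(min(x), max(x) + 1, 5))  # plt.xticks -> ax.set_xticks
ax.legend(loc="upper right")                     # plt.legend -> ax.legend (same name)
```

Calling `ax.xticks(...)` raises an AttributeError, which is the most likely stumbling block when converting pyplot code line by line.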