并排绘制Python图

时间:2020-02-06 07:26:49

标签: python matplotlib jupyter-notebook

以下代码生成2个图,在Jupyter笔记本上的另一个图上。如何使函数plotAllDist(...)从plot1Dist(...)接收图并将它们绘制为子图,而不是并排绘制?

我尝试阅读一些帖子,但无济于事...

def plot1Dist(x, sigmaJ, pmf, title):
    fig = plt.figure()

    freqTable = np.array(np.unique(sigmaJ, return_counts=True)).T
    simu = plt.plot(freqTable[:,0], freqTable[:,1], label='Simulation')
    dist = pmf * sum(freqTable[:,1])
    model = plt.plot(x, dist, label='Model')

    # add description to the plot
    plt.legend(loc="upper right")
    plt.xlabel('sigma')
    plt.ylabel('Frequency')
    plt.xticks(np.arange(min(x), max(x)+1, 5))
    plt.title(title)
    plt.show()

def plotAllDist(x, sigmaJ, e, pmf0, pmf1, FLIPS):    
    ONES = [i for i in range(e.size) if e[i] == 1]
    ZEROS = [j for j in range(e.size) if e[j] == 0]

    if (FLIPS == 0):
        title = 'Distribution of sigma before bit-flipping'
    elif (FLIPS > 0):
        title = 'Distribution of sigma after ' + str(FLIPS) + ' flips'

    plot1Dist(x, sigmaJ[ZEROS], pmf0, title + ' (e=0)')
    plot1Dist(x, sigmaJ[ONES], pmf1, title + ' (e=1)')

1 个答案:

答案 0 :(得分:0)

您应该使用plt.subplotsobject-oriented API for matplotlib。首先,避免在plot1Dist内创建新图形,因为这每次都会强制创建新图形(因此,在笔记本中,在屏幕上将它们绘制在另一个图形之上)。

我会这样修改您的代码:

def plot1Dist(ax, x, sigmaJ, pmf, title):
    "plots distribution onto Axes object `ax`."

    # Just replace all instances of `plt` with `ax`. Their APIs aren't exactly
    # the same but they are very similar
    freqTable = np.array(np.unique(sigmaJ, return_counts=True)).T
    simu = ax.plot(freqTable[:,0], freqTable[:,1], label='Simulation')
    dist = pmf * sum(freqTable[:,1])
    model = ax.plot(x, dist, label='Model')

    # add description to the plot
    ax.legend(loc="upper right")
    ax.set_xlabel('sigma')
    ax.set_ylabel('Frequency')
    ax.xticks(np.arange(min(x), max(x)+1, 5))
    ax.set_title(title)

def plotAllDist(x, sigmaJ, e, pmf0, pmf1, FLIPS):
    ONES = [i for i in range(e.size) if e[i] == 1]
    ZEROS = [j for j in range(e.size) if e[j] == 0]

    if (FLIPS == 0):
        title = 'Distribution of sigma before bit-flipping'
    elif (FLIPS > 0):
        title = 'Distribution of sigma after ' + str(FLIPS) + ' flips'

    # number of columns = 2, so that the plots are side by side
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 7))
    plot1Dist(axes[0], x, sigmaJ[ZEROS], pmf0, title + ' (e=0)')
    plot1Dist(axes[1], x, sigmaJ[ONES], pmf1, title + ' (e=1)')

我真的鼓励您检查object-oriented API!在我看来,它比“ pyplot” API更加混乱和灵活。