如何使用Matplotlib为所有子图设置相同的y轴比例?

时间:2020-10-26 15:59:51

标签: python matplotlib

我正在绘制多个表示同一度量指标的子图,希望所有子图的y轴刻度比例保持一致。该怎么做?

在下面的图片中,我显示一个示例:

目前y轴科学计数法的指数显示为 1e-3,但应当为 1e-4(相应地,刻度上的数值要乘以10)。另外,为了精确起见,最好也能让各子图刻度标签的小数位数保持一致。

(此处为示例图片,原文中的图片未能保留)

代码是:

def plot_loss(tr, te, _label, *, scilimits=(-3, 0)):
    """Plot mean training and validation loss per epoch, with std error bars.

    Draws on the current Matplotlib axes (``plt.gca()``), so the caller must
    create or select the target subplot first.

    Args:
        tr: Training-loss histories. Mean and std are taken over axis 0
            (presumably one row per run, one column per epoch -- confirm
            against ``get_history``).
        te: Validation-loss histories with the same layout as ``tr``.
        _label: Run label such as ``'i_y'``. Its first character prefixes
            the legend entries, and the part after the first ``'_'`` is
            upper-cased into the title.
        scilimits: Forwarded to ``ticklabel_format`` for the y-axis. The
            default ``(-3, 0)`` keeps the previous behaviour. Pass ``(m, m)``
            with ``m != 0`` (e.g. ``(-4, -4)``) to force every loss subplot
            to use the same ``10**m`` multiplier.
    """
    # Use the current axes instead of a module-level ``ax`` global, so the
    # function works no matter how the caller names its axes.
    ax = plt.gca()

    for history, suffix in ((tr, '_loss_tr'), (te, '_loss_vl')):
        runs = np.asarray(history)
        mean = runs.mean(axis=0)
        # Number epochs 1..N from the data rather than assuming 50 epochs.
        epochs = np.arange(1, len(mean) + 1)
        ax.errorbar(epochs, mean, yerr=runs.std(axis=0), label=_label[0] + suffix)

    # Shift the subplot down by 0.1 figure units, presumably to make room
    # for the figure-level legends.
    # NOTE(review): a later tight_layout()/subplots_adjust() may recompute
    # subplot positions and undo this shift -- confirm it is still needed.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 - 0.1, box.width, box.height])
    ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True, useOffset=False))

    ax.set_xlabel('Epoch', weight='bold', size=9)
    ax.set_ylabel('Mean Square Error', weight='bold', size=9)
    # Configures the ScalarFormatter installed above, so it must come after it.
    ax.ticklabel_format(axis="y", style="sci", scilimits=scilimits)
    ax.set_title('Training Loss %s' % _label.split('_')[1].upper(), weight='bold', size=10)
    plt.tight_layout()

def plot_acc(tr, te, _label):
    """Plot mean training and validation accuracy per epoch, with std error bars.

    Draws on the current Matplotlib axes (``plt.gca()``), so the caller must
    create or select the target subplot first.

    Args:
        tr: Training balanced-accuracy histories. Mean and std are taken over
            axis 0 (presumably one row per run, one column per epoch --
            confirm against ``get_history``).
        te: Validation histories with the same layout as ``tr``.
        _label: Run label such as ``'i_y'``. Its first character prefixes
            the legend entries, and the part after the first ``'_'`` is
            upper-cased into the title.
    """
    # Use the current axes instead of a module-level ``ax`` global.
    ax = plt.gca()

    for history, suffix in ((tr, '_acc_tr'), (te, '_acc_vl')):
        runs = np.asarray(history)
        mean = runs.mean(axis=0)
        # Number epochs 1..N from the data rather than assuming 50 epochs.
        epochs = np.arange(1, len(mean) + 1)
        ax.errorbar(epochs, mean, yerr=runs.std(axis=0), label=_label[0] + suffix)

    # Shift the subplot down by 0.1 figure units, presumably to make room
    # for the figure-level legends.
    # NOTE(review): a later tight_layout()/subplots_adjust() may recompute
    # subplot positions and undo this shift -- confirm it is still needed.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 - 0.1, box.width, box.height])

    ax.set_xlabel('Epoch', weight='bold', size=9)
    ax.set_ylabel('Balanced Accuracy Score', weight='bold', size=9)
    ax.set_title('Classification Accuracy %s' % _label.split('_')[1].upper(), weight='bold', size=10)
    plt.tight_layout()

fig = plt.figure(figsize=(14, 6.5))
# Axes of the first column. Later columns share their y-axis with these, so
# every loss subplot (top row) and every accuracy subplot (bottom row) gets
# the same y limits, ticks and exponent -- the goal of this figure.
first_loss_ax = None
first_acc_ax = None
for i, example in enumerate([('i_y', 'r_y'), ('i_a', 'r_a'), ('i_b', 'r_b'), ('i_n', 'r_n')]):
    # Create each column's two subplots exactly once. Calling add_subplot
    # again with identical arguments (once per label) creates an extra,
    # overlapping Axes in recent Matplotlib releases instead of reusing it.
    loss_ax = fig.add_subplot(2, 4, i + 1, sharey=first_loss_ax)
    acc_ax = fig.add_subplot(2, 4, i + 5, sharey=first_acc_ax)
    if first_loss_ax is None:
        first_loss_ax, first_acc_ax = loss_ax, acc_ax

    for label in example:
        loss_tr, loss_te, acc_tr, acc_te = get_history(path.join('../results/history/', label))

        # The plot helpers draw on the current axes (and may read the
        # module-level name ``ax``), so select the target subplot both ways.
        ax = loss_ax
        plt.sca(ax)
        plot_loss(loss_tr, loss_te, label)

        ax = acc_ax
        plt.sca(ax)
        plot_acc(acc_tr, acc_te, label)

# Build one figure legend per row, once, from the last column (which holds
# both labels' curves). Adding the legend inside the label loop drew two
# overlapping legends.
handles, labels = loss_ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, 0.52))
handles, labels = acc_ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, 0.05))

fig.subplots_adjust(hspace=0.7)

0 个答案:

没有答案