在多个matplotlib子图上共享两个y轴

时间:2018-01-17 14:25:54

标签: python matplotlib plot

我的目标是使用matplotlib获得两行和三列图。顶行中的每个图形将包含两个数据系列和两个y轴。我想让每个轴上的刻度线对齐,以便相应的数据系列可以直接比较。现在我有了它,以便每个图形上的主要y轴对齐,但我不能让次要的y轴对齐。这是我目前的代码:

import matplotlib.pyplot as plt
import pandas as pd


excel_file = 'test_data.xlsx'
sims = ['Sim 02', 'Sim 01',  'Sim 03']

if __name__ == '__main__':
    data = pd.read_excel(excel_file, skiprows=[0, 1, 2, 3], sheetname=None, header=1, index_col=[0, 1], skip_footer=10)
    plot_cols = len(sims)
    plot_rows = 2

    f, axes = plt.subplots(plot_rows, plot_cols, sharex='col', sharey='row')
    secondary_ax = []

    for i, sim in enumerate(sims):
        df = data[sim]
        modern = df.loc['Modern']
        traditional = df.loc['Traditional']
        axes[0][i].plot(modern.index, modern['Idle'])
        secondary_ax.append(axes[0][i].twinx())
        secondary_ax[i].plot(modern.index, modern['Work'])
        axes[1][i].bar(modern.index, modern['Result'])
        axes[0][i].set_xlim(12, 6)
        if i > 0:
            secondary_ax[0].get_shared_y_axes().join(secondary_ax[0], secondary_ax[i])

    # secondary_ax[0].get_shared_y_axes().join(x for x in secondary_ax)
    plt.show()

我尝试的解决方案(if语句中的行和plt.show()之前的最后一行)是类似问题的解决方案,但它没有解决我的问题。没有什么破坏,次要轴没有对齐。

Results of current code

我还尝试在plt.subplots方法中添加一个额外的行,并使用twinx()组合前两行,但是它创建了一个空的第二行图表。

作为后退,我想我可以手动检查每个轴的最大值和分钟数,然后循环查看每个轴以手动更新,但我很想找到一个更清洁的解决方案,如果有人在那里,任何人都有任何见解。感谢。

1 个答案:

答案 0 :(得分:0)

您需要在绘制数据之前共享y轴:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


# excel_file = 'test_data.xlsx'
sims = ['Sim 02', 'Sim 01',  'Sim 03']

if __name__ == '__main__':
    #     data = pd.read_excel(excel_file, skiprows=[0, 1, 2, 3], sheetname=None, header=1, index_col=[0, 1], skip_footer=10)
    modern = pd.DataFrame(np.random.randint(0, 100, (100, 3)), columns=sims)
    traditional = pd.DataFrame(np.random.randint(10, 30, (100, 3)), columns=sims)
    traditional[sims[1]] = traditional[sims[1]] + 40
    traditional[sims[2]] = traditional[sims[2]] - 40
    data3 = pd.DataFrame(np.random.randint(0, 100, (100, 3)), columns=sims)

    plot_cols = len(sims)
    plot_rows = 2

    f, axes = plt.subplots(plot_rows, plot_cols, sharex='col', sharey='row', figsize=(30, 10))
    secondary_ax = []


    for i, sim in enumerate(sims):
        df = data[sim]
        modern_series = modern[sim]
        traditional_series = traditional[sim]
        idle = data3

        axes[0][i].plot(modern_series.index, modern_series)
        secondary_ax.append(axes[0][i].twinx())
        if i > 0:
            secondary_ax[0].get_shared_y_axes().join(secondary_ax[0], secondary_ax[i])

        secondary_ax[i].plot(traditional_series.index, traditional_series)
        #         axes[1][i].bar(data3.index, data3)
        axes[0][i].set_xlim(12, 6)

    plt.show()

enter image description here