如何与许多子图共享Y轴幅度?

时间:2017-06-13 15:00:34

标签: python pandas matplotlib

我的CSV包含三列:DATE,LOC,CNT(示例如下)。我想要很多子图(最终将有大约200个)由此构成(我正在思考迷你图大小的图,但我还没那么远)。我遇到的问题是我的情节并没有共享Y轴幅度,因此具有非常不同数量的组看起来相似。

import pandas as pd
import math
import matplotlib.pyplot as plt

dfh = pd.read_csv('testdata.csv')
num_sites = dfh['LOC'].nunique()

cols = 5
rows = int(math.ceil(float(num_sites) / cols))

fig, axs = plt.subplots(nrows=rows, ncols=cols)

grouped = dfh.groupby('LOC')

targets = zip(grouped.groups.keys(), axs.flatten())
for i, (key, ax) in enumerate(targets):
    ax.plot(grouped.get_group(key)['CNT'])

plt.show()

示例数据:

DATE,LOC,CNT
2017-06-01,Loc 1,1
2017-06-02,Loc 1,6
2017-06-03,Loc 1,4
2017-06-04,Loc 1,1
2017-06-05,Loc 1,1
2017-06-01,Loc 2,0
2017-06-02,Loc 2,7
2017-06-03,Loc 2,4
2017-06-04,Loc 2,10
2017-06-05,Loc 2,12
2017-06-01,Loc 3,5
2017-06-02,Loc 3,2
2017-06-03,Loc 3,1
2017-06-04,Loc 3,8
2017-06-05,Loc 3,1
2017-06-01,Loc 4,19
2017-06-02,Loc 4,20
2017-06-03,Loc 4,15
2017-06-04,Loc 4,12
2017-06-05,Loc 4,22
2017-06-01,Loc 5,0
2017-06-02,Loc 5,1
2017-06-03,Loc 5,1
2017-06-04,Loc 5,2
2017-06-05,Loc 5,2
2017-06-01,Loc 6,7
2017-06-02,Loc 6,5
2017-06-03,Loc 6,7
2017-06-04,Loc 6,5
2017-06-05,Loc 6,6

这产生了:

Sample Output

请注意,我的最大Y为7,2,22,8,12 6,然后是一些最多为1的空白图。

我的问题:如何让这些图中的每个图共享相同的Y轴? X轴也应该是相同的,但我认为这只会让我删除X轴标签,因为我已在数据中确认每个组具有相同的X点。

奖励:有没有办法在最后删除那些空白的情节?我不能保证我会有一组填满整行的情节。

1 个答案:

答案 0 :(得分:1)

要分享所有yaxes,请使用

fig, axs = plt.subplots(nrows=rows, ncols=cols, sharey=True)

x:sharex=True

要删除多余的轴,可以通过附加

将其关闭
for j in range(i+1, len(axs.flatten())):
    axs.flatten()[j].axis("off")

其中i是前一循环的循环变量。

一个完整的工作示例:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

dfh = pd.DataFrame({"LOC" : np.random.randint(0,6, size=100),
              "CNT" : np.arange(100)})

cols = 5
rows = 2

fig, axs = plt.subplots(nrows=rows, ncols=cols, sharey=True, sharex=True)

grouped = dfh.groupby('LOC')

targets = zip(grouped.groups.keys(), axs.flatten())
for i, (key, ax) in enumerate(targets):
    ax.plot(grouped.get_group(key)['CNT'])

for j in range(i+1, len(axs.flatten())):
    axs.flatten()[j].axis("off")

plt.show()

enter image description here