matplotlib: line plot segment colors based on a flag column

Time: 2021-06-17 17:00:48

Tags: pandas matplotlib

I have this data:

import pandas as pd
import matplotlib.pyplot as plt

sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
# 7 weekly dates followed by 9 month-start dates, flattened into one array
ind = pd.concat([pd.DataFrame(pd.date_range(start='2020-1-1', periods=7, freq='W')),
                 pd.DataFrame(pd.date_range(start='2020-4-1', periods=9, freq='MS'))]).values.reshape((16,))

df = pd.DataFrame({
    'Sale': sale,
    'Season': season},
    index=ind,
)

That is:

            Sale    Season
2020-01-05  10      Winter
2020-01-12  20      Winter
2020-01-19  30      Winter
2020-01-26  40      Winter
2020-02-02  43      Winter
2020-02-09  46      Winter
2020-02-16  49      Winter
2020-04-01  50      Spring
2020-05-01  60      Spring
2020-06-01  70      Spring
2020-07-01  80      Summer
2020-08-01  90      Summer
2020-09-01  100     Summer
2020-10-01  110     Fall
2020-11-01  120     Fall
2020-12-01  130     Fall

And this color map:

colors_map = {'Winter': 'b',
              'Spring': 'pink',
              'Summer': 'y',
              'Fall': 'orange'}

I can easily plot a line like this:

df.plot();

Or a scatter plot like this:

plt.scatter(x=df.index, y=df['Sale'], c=df['Season'].map(colors_map))

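However, I can't work out how to plot a single line in which each segment has a different color according to the color map.

This seems to be a similar question: Plotting multiple segments with colors based on some variable with matplotlib

2 answers:

Answer 0: (score: 2)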

I would plot one column per season, which you can get with pivot or unstack.
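The answer only shows the unstack route, so here is a minimal sketch of what the pivot route could look like on the question's data. The column name `Date` and the variable names `long_df` and `data_pivot` are made up for illustration, and the index is built with `DatetimeIndex.append` instead of the `pd.concat` construction from the question.

```python
import pandas as pd

# Rebuild the question's frame (same dates and values as in the question).
sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
ind = pd.date_range('2020-01-01', periods=7, freq='W').append(
    pd.date_range('2020-04-01', periods=9, freq='MS'))
df = pd.DataFrame({'Sale': sale, 'Season': season}, index=ind)

# pivot needs the dates as an ordinary column, so name the index and reset it.
# 'Date' is a hypothetical name chosen here; it does not appear in the post.
long_df = df.rename_axis('Date').reset_index()

# One column per season, NaN wherever a date belongs to another season.
data_pivot = long_df.pivot(index='Date', columns='Season', values='Sale')

print(data_pivot.head())
```

This should hold the same values as the unstack result that follows; the only difference I would expect is that the index carries the name `Date`.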

>>> sales = df.set_index('Season', append=True)['Sale']
>>> data = sales.unstack('Season')
>>> data
Season       Fall  Spring  Summer  Winter
2020-01-01    NaN     NaN     NaN    10.0
2020-02-01    NaN     NaN     NaN    20.0
2020-03-01    NaN     NaN     NaN    30.0
2020-04-01    NaN    40.0     NaN     NaN
2020-05-01    NaN    50.0     NaN     NaN
2020-06-01    NaN    60.0     NaN     NaN
2020-07-01    NaN     NaN    70.0     NaN
2020-08-01    NaN     NaN    80.0     NaN
2020-09-01    NaN     NaN    90.0     NaN
2020-10-01  100.0     NaN     NaN     NaN
2020-11-01  110.0     NaN     NaN     NaN
2020-12-01  120.0     NaN     NaN     NaN

Call this new dataframe data; you can then simply plot it with:

data.plot(color=colors_map)

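The result looks like this:

[image: plot results]

This leaves gaps between the seasons, but it is much simpler than the other question you linked.

Some options can reduce the visual impact of the gaps and make it clear that each "point" actually stands for a whole month: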

data.plot(color=colors_map, drawstyle='steps-pre')

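[image: steps plot]

If that does not satisfy you, you need to duplicate the points at the boundary between 2 different columns.

First, let's select the values to fill in, making sure the columns are in a sensible order: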

>>> fillin = data.mask(data.isna() == data.isna().shift())
>>> fillin = fillin.reindex(['Winter', 'Spring', 'Summer', 'Fall'], axis='columns')
>>> fillin
Season      Winter  Spring  Summer   Fall
index                                    
2020-01-01    10.0     NaN     NaN    NaN
2020-02-01     NaN     NaN     NaN    NaN
2020-03-01     NaN     NaN     NaN    NaN
2020-04-01     NaN    40.0     NaN    NaN
2020-05-01     NaN     NaN     NaN    NaN
2020-06-01     NaN     NaN     NaN    NaN
2020-07-01     NaN     NaN    70.0    NaN
2020-08-01     NaN     NaN     NaN    NaN
2020-09-01     NaN     NaN     NaN    NaN
2020-10-01     NaN     NaN     NaN  100.0
2020-11-01     NaN     NaN     NaN    NaN
2020-12-01     NaN     NaN     NaN    NaN
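To see why that mask keeps only the first row of each season, here is a step-by-step sketch on the question's current data. The tables printed in this answer appear to come from an earlier version of the data, so the numbers differ. The intermediate names `missing` and `unchanged` are mine, not the answer's.

```python
import pandas as pd

# Rebuild the question's frame and the wide table from the unstack step.
sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
ind = pd.date_range('2020-01-01', periods=7, freq='W').append(
    pd.date_range('2020-04-01', periods=9, freq='MS'))
df = pd.DataFrame({'Sale': sale, 'Season': season}, index=ind)
data = df.set_index('Season', append=True)['Sale'].unstack('Season')

# True where a season has no value on that date.
missing = data.isna()
# True where the NaN pattern is the same as in the previous row.
# The first row is compared against NaN, so it is never "unchanged".
unchanged = missing == missing.shift()
# Blank out every cell whose NaN status repeats the previous row.
fillin = data.mask(unchanged)

# Only the rows where a season starts still hold a value.
print(fillin.dropna(how='all'))
```

Cells that switch from a value to NaN also count as "changed", but they are NaN anyway, so the only surviving numbers are the first point of each season.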

Now fill these values into data by rotating the columns:

>>> fillin.shift(-1, axis='columns').assign(Fall=fillin['Winter'])
Season      Winter  Spring  Summer  Fall
index                                   
2020-01-01     NaN     NaN     NaN  10.0
2020-02-01     NaN     NaN     NaN   NaN
2020-03-01     NaN     NaN     NaN   NaN
2020-04-01    40.0     NaN     NaN   NaN
2020-05-01     NaN     NaN     NaN   NaN
2020-06-01     NaN     NaN     NaN   NaN
2020-07-01     NaN    70.0     NaN   NaN
2020-08-01     NaN     NaN     NaN   NaN
2020-09-01     NaN     NaN     NaN   NaN
2020-10-01     NaN     NaN   100.0   NaN
2020-11-01     NaN     NaN     NaN   NaN
2020-12-01     NaN     NaN     NaN   NaN
>>> data.fillna(fillin.shift(-1, axis='columns').assign(Fall=fillin['Winter'])).plot(color=colors_map)

[image]
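A quick way to convince yourself that the rotation closes the gaps is to run the same steps on the question's current data and look at the rows around a boundary. This is a sketch only; `order`, `rotated` and `filled` are names I introduced here.

```python
import pandas as pd

# Rebuild the question's frame and the wide table from the unstack step.
sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
ind = pd.date_range('2020-01-01', periods=7, freq='W').append(
    pd.date_range('2020-04-01', periods=9, freq='MS'))
df = pd.DataFrame({'Sale': sale, 'Season': season}, index=ind)
data = df.set_index('Season', append=True)['Sale'].unstack('Season')

# First point of each season, with the columns in calendar order.
order = ['Winter', 'Spring', 'Summer', 'Fall']
fillin = data.mask(data.isna() == data.isna().shift()).reindex(order, axis='columns')

# Move each season's first point into the previous season's column;
# Winter's first point wraps around into Fall.
rotated = fillin.shift(-1, axis='columns').assign(Fall=fillin['Winter'])

# fillna aligns on index and column labels, so each season's column
# gains one extra point: the first value of the season that follows it.
filled = data.fillna(rotated)

print(filled.loc['2020-02-16':'2020-05-01', ['Winter', 'Spring']])
```

With this data the wrap-around puts Winter's first value into the Fall column on the very first date. It has no neighbouring Fall value, so a plain line plot without markers should not draw anything for it.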

Here is the final result with the new data from your post; my code is unchanged:

[image]

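Answer 1: (score: 1)

I believe reshaping is the way to go, since it is only needed for plotting. If you want a method without reshaping, you can use a for loop, take each season (per year), and plot the pieces independently on the same figure. Note that loc includes both bounds, so when selecting Winter you also get the first element of Spring, which is what makes the plot continuous.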
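The inclusive behaviour of loc is easy to check in isolation. A small sketch on the question's data, where `start`, `end` and `winter_piece` are names picked for this example:

```python
import pandas as pd

# Rebuild the question's frame.
sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
ind = pd.date_range('2020-01-01', periods=7, freq='W').append(
    pd.date_range('2020-04-01', periods=9, freq='MS'))
df = pd.DataFrame({'Sale': sale, 'Season': season}, index=ind)

# Slice from the first Winter date to the first Spring date.
start, end = pd.Timestamp('2020-01-05'), pd.Timestamp('2020-04-01')
winter_piece = df.loc[start:end]

# Label-based slices include the end label, so the last row is already Spring.
print(winter_piece.tail(2))
```

The loop below relies on exactly this overlap to join consecutive pieces.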

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# index labels where the season changes (first row of each season)
season_changed = df.index[df['Season'].ne(df['Season'].shift())].tolist()

# create the figure
fig, ax = plt.subplots()
# iterate over each (season, year) run; loc slices include both ends,
# so each piece also contains the first point of the next season
for start, end, season in zip(season_changed,
                              season_changed[1:] + [df.index[-1]],
                              df.loc[season_changed, 'Season']):
    df.loc[start:end, 'Sale'].plot(ax=ax, c=colors_map[season])

# build the legend by hand, one patch per season
handles = [mpatches.Patch(color=val, label=key)
           for key, val in colors_map.items()]
plt.legend(handles=handles, loc='best')

plt.show()
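For completeness, another way to get one continuous line with per-segment colors is matplotlib's LineCollection. This is not taken from either answer here; it is a rough sketch under the assumption that each segment should take the color of the season at its starting point, which matches what the loop above draws. The names `points`, `segments` and `seg_colors` are mine.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs anywhere
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection

# Rebuild the question's frame and color map.
sale = [10, 20, 30, 40, 43, 46, 49, 50, 60, 70, 80, 90, 100, 110, 120, 130]
season = ['Winter'] * 7 + ['Spring'] * 3 + ['Summer'] * 3 + ['Fall'] * 3
ind = pd.date_range('2020-01-01', periods=7, freq='W').append(
    pd.date_range('2020-04-01', periods=9, freq='MS'))
df = pd.DataFrame({'Sale': sale, 'Season': season}, index=ind)
colors_map = {'Winter': 'b', 'Spring': 'pink', 'Summer': 'y', 'Fall': 'orange'}

# Convert dates to matplotlib's float representation and pair them with sales.
x = mdates.date2num(df.index.to_pydatetime())
y = df['Sale'].to_numpy(dtype=float)
points = np.column_stack([x, y]).reshape(-1, 1, 2)

# Each segment joins point i to point i + 1: shape (n - 1, 2, 2).
segments = np.concatenate([points[:-1], points[1:]], axis=1)

# Assumption: a segment is colored by the season of its starting point.
seg_colors = df['Season'].map(colors_map).tolist()[:-1]

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=seg_colors))
ax.xaxis_date()        # show the x axis as dates again
ax.autoscale_view()    # collections do not rescale the axes by themselves
fig.autofmt_xdate()
```

In an interactive session you would drop the `matplotlib.use('Agg')` line and finish with `plt.show()`.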