Matplotlib:熊猫MultiIndex DataFrame的自定义代码

时间:2018-08-18 14:08:15

标签: python pandas matplotlib multi-index ticker

我想绘制一个大的pandas MultiIndex DataFrame。一个最小的示例如下:

import pandas as pd

years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']

index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])

df = pd.DataFrame(0, index=index, columns=columns)

df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1

如果我使用plt.spy进行绘制,则会得到:

simple plot

但是,刻度位置和标签不理想。我希望刻度线完全忽略MultiIndex的第二级。使用IndexLocatorIndexFormatter,我可以执行以下操作:

from matplotlib.ticker import IndexFormatter, IndexLocator

import matplotlib.pyplot as plt

ax = plt.gca()
plt.spy(df)

xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()

ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')

plt.show()

这正是我想要的:

enter image description here

但这是问题所在。我的实际DataFrame有15年,4,000个字段,365天和7个带。如果我实际上每天都贴标签,那么标签将难以辨认。我可以每隔50天放置一个刻度,但是我希望刻度是动态的,这样当我放大时,刻度会变得更细。基本上,我要寻找的是一个自定义MultiIndexLocator,它将IndexLocator的位置与MaxNLocator的动态结合在一起。

奖金:在每年总是有相同数量的字段并且每天都有相同数量的频段的意义上,我的数据确实很棒。但是,如果不是这种情况怎么办?我很乐意为matplotlib贡献一个适用于任何MultiIndex DataFrame的通用MultiIndexLocatorMultiIndexFormatter

1 个答案:

答案 0 :(得分:1)

Matplotlib不了解数据帧或MultiIndex。它只是绘制您提供的数据。即您将获得与绘制numpy数据数组spy(df.values)相同的效果。

因此,我建议您首先正确设置图像的范围,以便可以使用数字行情指示器。然后MaxNLocator应该可以正常工作,除非您不要放大太多。

import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False

years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']

index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])

data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)

############
# Plotting
############

xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')

extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
          xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
          ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
          ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]

fig, ax = plt.subplots()

ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')

ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))


plt.show()

enter image description here