我想绘制一个大的pandas MultiIndex DataFrame。一个最小的示例如下:
import pandas as pd
years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
df = pd.DataFrame(0, index=index, columns=columns)
df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1
如果我使用plt.spy
进行绘制,则会得到:
但是,刻度位置和标签不理想。我希望刻度线完全忽略MultiIndex的第二级。使用IndexLocator
和IndexFormatter
,我可以执行以下操作:
from matplotlib.ticker import IndexFormatter, IndexLocator
import matplotlib.pyplot as plt
ax = plt.gca()
plt.spy(df)
xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()
ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')
plt.show()
这正是我想要的:
但这是问题所在。我的实际DataFrame有15年,4,000个字段,365天和7个带。如果我实际上每天都贴标签,那么标签将难以辨认。我可以每隔50天放置一个刻度,但是我希望刻度是动态的,这样当我放大时,刻度会变得更细。基本上,我要寻找的是一个自定义MultiIndexLocator
,它将IndexLocator
的位置与MaxNLocator
的动态结合在一起。
奖金:在每年总是有相同数量的字段并且每天都有相同数量的频段的意义上,我的数据确实很棒。但是,如果不是这种情况怎么办?我很乐意为matplotlib
贡献一个适用于任何MultiIndex DataFrame的通用MultiIndexLocator
和MultiIndexFormatter
。
答案 0 :(得分:1)
Matplotlib不了解数据帧或MultiIndex。它只是绘制您提供的数据。即您将获得与绘制numpy数据数组spy(df.values)
相同的效果。
因此,我建议您首先正确设置图像的范围,以便可以使用数字行情指示器。然后MaxNLocator
应该可以正常工作,除非您不要放大太多。
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False
years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)
############
# Plotting
############
xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')
extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]
fig, ax = plt.subplots()
ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')
ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
plt.show()