matplotlib中的常见图例

时间:2016-06-09 10:40:14

标签: python matplotlib charts legend

我在matplotlib中绘制网格。我有一个像这样的3 * 2子图网格:

Image

我正在绘制每个子图中的折线图,并且该线的每种颜色都指定一个类别。 让我们说color_name = {cat1:red,cat2:black,...} 现在,颜色的含义在每个子图中都是相同的,但它们可能只包含color_name dict类别的子集。我需要一个包含颜色的绘图的常见图例 - 来自dict color_names的类别对。 我已经搜索了很多,但我没有得到一个方法直接使用dict制作图例,以便只包含一次dict中的每个键值对。 这里X轴包含日期,y值是每天。

代码:

def graph_data(data):
''' Return html for graph for one department '''

fig, axes = plt.subplots( nrows=3,
                          ncols=2,
                          sharex = True,
                          figsize=(14, 10),
                          facecolor = 'w' , 
                          )

rownames = sorted(data.keys())
beautify(fig, axes, rownames )

plt.close()

color_name = get_color_name_dict()

for i, metric in enumerate(sorted(data.keys())):
    for j, device in enumerate(sorted(data[metric].keys())):
        ax = axes[i , j]

        ylow = 10000
        yhigh = 0

        for page in data[metric][device].keys():
            dd = data[metric][device][page]

            x = dd.keys()

            y = [statify(dd[_date]) for _date in x]

            ylow = min(ylow , min(y))
            yhigh = max(yhigh, max(y))

            ax.xaxis.set_major_formatter(mdates.DateFormatter('%d-%m-%Y'))
            ax.xaxis.set_major_locator(mdates.DayLocator())

            ax.plot(y , color = color_name[page], label = page , linewidth = '2')

            ax.fmt_xdata = mdates.DateFormatter('%d-%m-%Y')

            for tick in ax.get_xticklabels():
                tick.set_rotation(45)


            fig.autofmt_xdate()

        ax.set_ylim([ylow-200 , yhigh+200])



fig.tight_layout()
# tight_layout doesn't take these labels into account. We'll need 
# to make some room. These numbers are are manually tweaked. 
fig.subplots_adjust(left=0.18, top=0.95, right=0.95, bottom = 0.05)

imgdata = cStringIO.StringIO()
fig.savefig(imgdata, format='png' , facecolor = fig.get_facecolor())

s = '<img alt = "embedded" src = "data:image/png;base64,%s"/>' % imgdata.getvalue().encode("base64").strip()

orig_stdout = sys.stdout
f = open("test.html" , "w")
sys.stdout = f

print s

sys.stdout = orig_stdout
f.close()

print "done"
mpld3.show(fig)
#return s

美化功能:

def beautify(fig, axes, rows):
''' Beautify the plot '''
style.use('ggplot')

cols = ['Desktop' , 'Mobile']

for ax in axes[:,0]:
    ax.set_ylabel("Median (ms)")




pad = 5 # in points

for ax, col_name in zip(axes[0], cols):
    ax.annotate(col_name, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size='large', ha='center', va='baseline')

for ax, row_name in zip(axes[:,0], rows):
    ax.annotate(row_name, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size='large', ha='right', va='center')

0 个答案:

没有答案