使用自定义处理程序时,忽略了matplotlib图例参数

时间:2018-06-21 12:18:12

标签: matplotlib legend

我正在尝试绘制一个自定义图例,其中一行上有一些半径不断增大的圆圈,后跟“收入”一词。我认为这是一种很好的方式来表明圆圈的大小与受试者的收入相对应。

图例必须手动绘制。这是我的实现方式:

class AnyObjectHandler(HandlerBase):
    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):

        legend.numpoints = 1
        l1 = plt.Line2D([x0 - 40, y0 + width], [0.3 * height, 0.3 * height], color='blue',
                        marker='o', markersize=10, markerfacecolor="blue")

        return [l1]

fig.legend([object], ['Income'], numpoints=1,
           handler_map={object: AnyObjectHandler()})

问题在于,即使我尝试两次指定numpoints == 1,图例仍然每行默认带有2个标记。一个相关的问题(我发现如何将numpoints设置为1的地方是:matplotlib Legend Markers Only Once

这是上面的代码产生的:

Currently produced legen

相反,我希望线条仅显示一个圆圈。没关系。

1 个答案:

答案 0 :(得分:2)

您可以在处理程序的Line2D方法内创建第二个create_artist。这允许在行中使用Line2D,在标记中使用另一个。

import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerBase

class AnyObjectHandler(HandlerBase):
    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):

        l1 = plt.Line2D([x0 - 40, x0 + width], [0.3 * height, 0.3 * height], 
                        color='blue', marker='None')

        m1 = plt.Line2D([x0 - 40], [0.3 * height], color='blue', linestyle="",
                        marker='o', markersize=10, markerfacecolor="blue")

        return [l1, m1]


fig, ax = plt.subplots()

fig.legend([object], ['Income'],
           handler_map={object: AnyObjectHandler()})

plt.show()

enter image description here

此解决方案与numpoints参数无关,并且在您已经知道只需要一个点的情况下可能很有用。

或者,您可以访问numpoints以指定您要使用的点数。最好通过对实际上知道numpoints参数的处理程序进行子类化来实现。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerNpoints

class AnyObjectHandler(HandlerNpoints):
    def create_artists(self, legend, orig_handle,
                       x0, y0, width, height, fontsize, trans):

        l1 = plt.Line2D([x0 - 40, x0 + width], [0.3 * height, 0.3 * height], 
                        color='blue', marker='None')

        num = self.get_numpoints(legend)
        if num == 1:
            xdata = [x0 - 40]
        else:
            xdata = np.linspace(x0 - 40, x0 + width, num)
        m1 = plt.Line2D(xdata, [0.3 * height]*len(xdata), color='blue', 
                        linestyle="", marker='o', markersize=10)

        return [l1, m1]


fig, ax = plt.subplots()

fig.legend([object], ['Income'], numpoints=1,
           handler_map={object: AnyObjectHandler()})

plt.show()

对于numpoints=1,其结果与上述相同,但是您可以指定numpoints=2

enter image description here

numpoints=3

enter image description here