在matplotlib图例中插入图像

时间:2014-09-25 02:34:12

标签: python numpy matplotlib

我想在maplotlib图的图例中插入几个小图形(矢量图形,但必要时可以制作光栅)。图例中的每个项目都会有一个图形。

我知道我可以使用something like an annotation box手动绘制整个图例,但这看起来很乏味,图中的任何小变化都需要手工修复。

pyplot.plot来电pyplot.legend或更晚的电话中,有没有办法在标签中加入图片?

1 个答案:

答案 0 :(得分:10)

所以,下面有点hacky,但它可以让你在那里大部分时间。注意:您需要将[PATH TO IMAGE]替换为您想要的图像(否则您将免费获得Grace Hopper!)。您还可以通过传递image_stretch参数使图像大于默认值。这是修复图像宽高比的黑客方法。如果图像从一个系列重叠到下一个系列,请使用labelspacing参数。

import os

from matplotlib.transforms import TransformedBbox
from matplotlib.image import BboxImage
from matplotlib.legend_handler import HandlerBase
from matplotlib._png import read_png

class ImageHandler(HandlerBase):
    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize,
                       trans):

        # enlarge the image by these margins
        sx, sy = self.image_stretch 

        # create a bounding box to house the image
        bb = Bbox.from_bounds(xdescent - sx,
                              ydescent - sy,
                              width + sx,
                              height + sy)

        tbb = TransformedBbox(bb, trans)
        image = BboxImage(tbb)
        image.set_data(self.image_data)

        self.update_prop(image, orig_handle, legend)

        return [image]

    def set_image(self, image_path, image_stretch=(0, 0)):
        if not os.path.exists(image_path):
            sample = get_sample_data("grace_hopper.png", asfileobj=False)
            self.image_data = read_png(sample)
        else:
            self.image_data = read_png(image_path)

        self.image_stretch = image_stretch

# random data
x = np.random.randn(100)
y = np.random.randn(100)
y2 = np.random.randn(100)

# plot two series of scatter data
s = plt.scatter(x, y, c='b')
s2 = plt.scatter(x, y2, c='r')

# setup the handler instance for the scattered data
custom_handler = ImageHandler()
custom_handler.set_image("[PATH TO IMAGE]",
                         image_stretch=(0, 20)) # this is for grace hopper

# add the legend for the scattered data, mapping the
# scattered points to the custom handler
plt.legend([s, s2],
           ['Scatters 1', 'Scatters 2'],
           handler_map={s: custom_handler, s2: custom_handler},
           labelspacing=2,
           frameon=False)

这是它产生的东西:

grace hopper