Matplotlib:在plt.subplots中绘制多个直方图

时间:2019-02-18 15:48:31

标签: python python-3.x matplotlib

问题背景:

我正在研究一个类,该类将轴对象作为构造函数参数,并生成一个(m,n)尺寸图,每个图中均带有直方图单元格,如下图所示:

enter image description here

这里有两点要注意,不允许以任何方式进行修改:

  1. Figure对象未通过作为构造函数参数;只有轴对象。因此,不能以任何方式修改 subplots对象。
  2. 默认情况下,“轴”参数设置为(1,1)图形的参数(如下所示)。使其成为(m,n)图形所需的所有修改都在该类内进行(在其方法之内)
_, ax = plt.subplots() # By default takes (1,1) dimension
cm = ClassName(model, ax=ax, histogram=True) # calling my class

我遇到的问题:

由于我想在每个单元格内绘制直方图,因此我决定通过遍历每个单元格并为每个单元格创建直方图来实现此目的。

results[col].hist(ax=self.ax[y,x], bins=bins)

但是,我无法以任何方式指定直方图的轴。这是因为传递的 Axes参数具有默认尺寸(1,1),因此无法索引。当我尝试这个时,我得到一个TypeError的提示。

TypeError: 'AxesSubplot' object is not subscriptable

考虑了所有这些之后,我想知道将直方图添加到父Axes对象的任何可能方式。感谢您的关注。

1 个答案:

答案 0 :(得分:1)

要求非常严格,可能不是最佳设计选择。因为您以后要在单个子图的位置上绘制多个子图,所以创建该单个子图仅是为了死掉,稍后再进行替换。

因此,您可以做的是获取传入的轴的位置,然后在该位置创建一个新的gridspec。然后删除原始轴,并在该新创建的gridspec中创建一组新轴。

以下将是一个示例。请注意,当前要求传入的轴为Subplot(与任何轴相对)。 它还将图的数量硬编码为2*2。在实际用例中,您可能会从传入的model中得出该数字。

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec

class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax or plt.gca()
        if not hasattr(ax, "get_gridspec"):
            raise ValueError("Axes needs to be a subplot")
        parentgs = ax.get_gridspec()
        q = ax.get_geometry()[-1]

        # Geometry of subplots
        m, n = 2, 2
        gs = gridspec.GridSpecFromSubplotSpec(m,n, subplot_spec=parentgs[q-1])
        fig = ax.figure
        ax.remove()

        self.axes = np.empty((m,n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i,j] = fig.add_subplot(gs[i,j], label=f"{i}{j}")


    def plot(self, data):
        for ax,d in zip(self.axes.flat, data):
            ax.plot(d)


_, (ax,ax2) = plt.subplots(ncols=2) 
cm = ClassName("mymodel", ax=ax2) # calling my class
cm.plot(np.random.rand(4,10))

plt.show()

enter image description here