用功能生成的图形填充matplotlib子图

时间:2018-07-17 15:46:29

标签: python matplotlib subplot multiple-axes

我已经找到并修改了以下代码段,以生成用于线性回归的诊断图。目前,这是使用以下功能完成的:

def residual_plot(some_values):
    plot_lm_1 = plt.figure(1)    
    plot_lm_1 = sns.residplot()
    plot_lm_1.axes[0].set_title('title')
    plot_lm_1.axes[0].set_xlabel('label')
    plot_lm_1.axes[0].set_ylabel('label')
    plt.show()


def qq_plot(residuals):
    QQ = ProbPlot(residuals)
    plot_lm_2 = QQ.qqplot()    
    plot_lm_2.axes[0].set_title('title')
    plot_lm_2.axes[0].set_xlabel('label')
    plot_lm_2.axes[0].set_ylabel('label')
    plt.show()

以类似这样的方式调用:

plot1 = residual_plot(value_set1)
plot2 = qq_plot(value_set1)
plot3 = residual_plot(value_set2)
plot4 = qq_plot(value_set2)

如何创建subplots以便将这4个图显示在2x2网格中?
我尝试使用:

fig, axes = plt.subplots(2,2)
    axes[0,0].plot1
    axes[0,1].plot2
    axes[1,0].plot3
    axes[1,1].plot4
    plt.show()

但收到错误:

AttributeError: 'AxesSubplot' object has no attribute 'plot1'

我应该在函数内部还是在其他地方设置轴属性?

1 个答案:

答案 0 :(得分:2)

您应该创建一个具有四个子图轴的图形,然后将其用作自定义图函数的输入轴

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import probplot


def residual_plot(x, y, axes = None):
    if axes is None:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 1, 1)
    else:
        ax1 = axes
    p = sns.residplot(x, y, ax = ax1)
    ax1.set_xlabel("Data")
    ax1.set_ylabel("Residual")
    ax1.set_title("Residuals")
    return p


def qq_plot(x, axes = None):
    if axes is None:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 1, 1)
    else:
        ax1 = axes
    p = probplot(x, plot = ax1)
    ax1.set_xlim(-3, 3)
    return p


if __name__ == "__main__":
    # Generate data
    x = np.arange(100)
    y = 0.5 * x
    y1 = y + np.random.randn(100)
    y2 = y + np.random.randn(100)

    # Initialize figure and axes
    fig = plt.figure(figsize = (8, 8), facecolor = "white")
    ax1 = fig.add_subplot(2, 2, 1)
    ax2 = fig.add_subplot(2, 2, 2)
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)

    # Plot data
    p1 = residual_plot(y, y1, ax1)
    p2 = qq_plot(y1, ax2)
    p3 = residual_plot(y, y2, ax3)
    p4 = qq_plot(y2, ax4)

    fig.tight_layout()
    fig.show()

我不知道您的ProbPlot功能是什么,所以我只选了SciPy的功能。