Python - 将残差添加到for循环生成的子图中

时间:2017-06-05 23:28:11

标签: python for-loop matplotlib subplot

在使用add_axes添加残差时,我无法使子图能够正常工作。它没有残差,效果很好,我可以在一个图中添加残差。这是我正在做的一个例子:

首先,只是为了让您了解我要绘制的内容,(t,y)是我想绘制的数据,拟合是否适合数据,而diff是拟合和数据之间的差异

t, s, fit = [], [], []
diff = []
for i in range(12):
    t.append(x / y[i])

    s.append(np.linspace(0, 1, num=100, endpoint=True))
    fit.append(UnivariateSpline(t[i], y, er, s=5e20))
    diff.append(fit[i](t[i]) - y)

这就是数字:

fig = plt.figure()
for i in range(12):
    plt.subplot(4,3,i+1)
    fig.add_axes((0.,0.3,0.7,0.9))
    plt.plot(s[i], fit[i](s[i]), 'r-') # this is the fit
    plt.errorbar(t[i], y, er, fmt='.k',ms=6) # this is the data 
    plt.axis([0,1, 190, 360])

    fig.add_axes((0.,0.,0.7,0.3))       
    plot(t[i],diff[i],'or') # this are the residuals
    plt.axis([0,1, 190, 360])

因此,你可以看到我生成了12个子图,直到我添加fig.add_axes来分隔数据+拟合和残差之间的每个子图,但是我得到的是一个杂乱的情节在子图上面(图已经缩小以查看下面的子图):

messy plot

我想要的是12个子图,每个子图看起来像这样:

correct plot

1 个答案:

答案 0 :(得分:1)

通常plt.subplot(..)fig.add_axes(..)是互补的。这意味着两个命令都在图中创建了一个轴。

然而,他们的用法会有所不同。要使用subplot创建12个子图,您可以

for i in range(12):
    plt.subplot(4,3,i+1)
    plt.plot(x[i],y[i])

要使用add_axes创建12个子图,您需要执行类似这样的操作

for i in range(12):
    ax = fig.add_axes([.1+(i%3)*0.8/3, 0.7-(i//3)*0.8/4, 0.2,.18])
    ax.plot(x[i],y[i])

其中轴的位置需要传递给add_axes

两者都很好。但是将它们组合起来并不是直截了当的,因为子图是根据网格定位的,而使用add_axes则需要知道网格位置。

所以我建议从头开始。创建子图的合理而干净的方法是使用plt.subplots()

fig, axes = plt.subplots(nrows=4, ncols=3)
for i, ax in enumerate(axes.flatten()):
    ax.plot(x[i],y[i])

每个子图可以使用轴分割器(make_axes_locatable

分成2个
from mpl_toolkits.axes_grid1 import make_axes_locatable
divider = make_axes_locatable(ax)
ax2 = divider.append_axes("bottom", size=size, pad=pad)
ax.figure.add_axes(ax2)

因此,在轴上循环并对每个轴执行上述操作可以获得所需的网格。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
plt.rcParams["font.size"] = 8

x = np.linspace(0,2*np.pi)
amp = lambda x, phase: np.sin(x-phase)
p = lambda x, m, n: m+x**(n)

fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(8,6), sharey=True, sharex=True)

def createplot(ax, x, m, n, size="20%", pad=0):
    divider = make_axes_locatable(ax)
    ax2 = divider.append_axes("bottom", size=size, pad=pad)
    ax.figure.add_axes(ax2)
    ax.plot(x, amp(x, p(x,m,n)))
    ax2.plot(x, p(x,m,n), color="crimson")
    ax.set_xticks([])

for i in range(axes.shape[0]):
    for j in range(axes.shape[1]):
        phase = i*np.pi/2
        createplot(axes[i,j], x, i*np.pi/2, j/2.,size="36%")

plt.tight_layout()        
plt.show()

enter image description here