如何删除matplotlib.pyplot中子图之间的空间?

时间:2016-12-10 03:45:56

标签: python numpy matplotlib

我正在开展一个项目,我需要将10行3列的绘图网格组合在一起。虽然我已经能够绘制图并安排子图,但是我无法生成没有空格的漂亮图,例如gridspec documentatation下面的image w/o white spaceMatplotlib different size subplots

我尝试了以下帖子,但仍然无法像示例图片中那样完全删除空白区域。有人可以给我一些指导吗?谢谢!

这是我的形象:The full script is here on GitHub

以下是我的代码。 {{3}}。 注意:images_2和images_fool都是具有形状(1032,10)的扁平图像的numpy数组,而delta是形状的图像数组(28,28)。

def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')

3 个答案:

答案 0 :(得分:19)

开头的注释:如果您想要完全控制间距,请避免使用plt.tight_layout(),因为它会尝试将图中的图形排列为同样且分布均匀。这大部分都很好并且会产生令人愉快的结果,但会随意调整间距。

您从Matplotlib示例库中引用的GridSpec示例运行良好的原因是因为子图“'方面未预定义。也就是说,子图将简单地在网格上展开并保留设置的间距(在这种情况下为wspace=0.0, hspace=0.0),与图形大小无关。

与使用imshow绘制图像相反,默认情况下图像的方面设置相等(相当于ax.set_aspect("equal"))。也就是说,您当然可以将set_aspect("auto")添加到每个绘图中(并且另外添加wspace=0.0, hspace=0.0作为GridSpec的参数,如在库示例中),这将生成没有间距的绘图。

然而,当使用图像时,保持相同的宽高比使得每个像素都宽到高并且正方形阵列显示为方形图像是很有意义的。
您需要做的是使用图像大小和图形边距来获得预期结果。图中的figsize参数是以英寸为单位的图形(宽度,高度),这里可以使用两个数字的比率。并且可以手动调整子图参数wspace, hspace, top, bottom, left以给出期望的结果。 以下是一个例子:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

nrow = 10
ncol = 3

fig = plt.figure(figsize=(4, 10)) 

gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i in range(10):
    for j in range(3):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

#plt.tight_layout() # do not use this!!
plt.show()

enter image description here

修改
当然希望不必手动调整参数。因此,可以根据行数和列数计算一些最优的。

nrow = 7
ncol = 7

fig = plt.figure(figsize=(ncol+1, nrow+1)) 

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        im = np.random.rand(28,28)
        ax= plt.subplot(gs[i,j])
        ax.imshow(im)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

plt.show()

答案 1 :(得分:5)

尝试将此行添加到您的代码中:

fig.subplots_adjust(wspace=0, hspace=0)

对于每个轴对象集:

ax.set_xticklabels([])
ax.set_yticklabels([])

答案 2 :(得分:0)

以下是ImportanceOfBeingErnest的答案,但是如果您想使用plt.subplots及其功能:

fig, axes = plt.subplots(
    nrow, ncol,
    gridspec_kw=dict(wspace=0.0, hspace=0.0,
                     top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                     left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
    figsize=(ncol + 1, nrow + 1),
    sharey='row', sharex='col', #  optionally
)