在Python中保存异构图像网格

时间:2018-11-12 02:57:46

标签: python image numpy matplotlib scipy

如何使用下面的代码将图像保存在4x4的异构图像网格上?想象一下,图像是由sample [i]识别的,而i取16个不同的值。

scipy.misc.imsave(str(img_index) + '.png', sample[1])

类似于此答案,但适用于16个不同的图像 https://stackoverflow.com/a/42041135/2414957

我对所采用的方法没有偏见,只要能够做到。另外,我对保存图像感兴趣,而不是像使用远程服务器并处理CelebA图像数据集(这是一个巨大的数据集)一样,使用plt.show()显示它们。我只想从我的批次中随机选择16张图像并保存DCGAN的结果,看看它是否有意义或是否收敛。

*当前,我正在保存如下图像:

batch_no = random.randint(0, 63)


scipy.misc.imsave('sample_gan_images/iter_%d_epoch_%d_sample_%d.png' %(itr, epoch, batch_no), sample[batch_no])

在这里,我有25个时期和2000次迭代,批处理大小为64。

2 个答案:

答案 0 :(得分:1)

我个人倾向于在这种情况下使用matplotlib.pyplot.subplots。如果您的图像确实是异类的,那么它可能是比您链接到的答案中基于图像串联的方法更好的选择。

import matplotlib.pyplot as plt
from scipy.misc import face

x = 4
y = 4

fig,axarr = plt.subplots(x,y)
ims = [face() for i in range(x*y)]

for ax,im in zip(axarr.ravel(), ims):
    ax.imshow(im)

fig.savefig('faces.png')

default plt.subplots behavior

我对subplots的最大抱怨是结果图中空白的数量。同样,对于您的应用程序,您可能不希望轴刻度/框架。这是处理这些问题的包装器函数:

import matplotlib.pyplot as plt

def savegrid(ims, rows=None, cols=None, fill=True, showax=False):
    if rows is None != cols is None:
        raise ValueError("Set either both rows and cols or neither.")

    if rows is None:
        rows = len(ims)
        cols = 1

    gridspec_kw = {'wspace': 0, 'hspace': 0} if fill else {}
    fig,axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw)

    if fill:
        bleed = 0
        fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))

    for ax,im in zip(axarr.ravel(), ims):
        ax.imshow(im)
        if not showax:
            ax.set_axis_off()

    kwargs = {'pad_inches': .01} if fill else {}
    fig.savefig('faces.png', **kwargs)

在与先前使用的同一组图像上运行savegrid(ims, 4, 4)会产生以下结果:

output of savegrid wrapper

如果使用savegrid,如果希望每个图像占用更少的空间,请传递fill=False关键字arg。如果要显示轴刻度/帧,请传递showax=True

答案 1 :(得分:1)

我在github上找到了这个,也分享了它:

import matplotlib.pyplot as plt

def merge_images(image_batch, size):
    h,w = image_batch.shape[1], image_batch.shape[2]
    c = image_batch.shape[3]
    img = np.zeros((int(h*size[0]), w*size[1], c))
    for idx, im in enumerate(image_batch):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w,:] = im
    return img

im_merged = merge_images(sample, [8,8])
plt.imsave('sample_gan_images/im_merged.png', im_merged )

enter image description here