将网格图像中的每张图像分别单独保存

时间:2019-11-25 03:01:15

标签: python image generator pytorch save-image

我想把网格图像中的每张图像分别保存为单独的文件。应该如何修改这段代码？或者，能否给我指出一个可供参考的答案？

class Generator(nn.Module):
    """Conditional image generator configured from the global ``opt``.

    Sub-modules created here:
      * ``label_emb`` -- ``nn.Embedding`` mapping a class index in
        ``0 .. opt.n_classes - 1`` to a vector of size ``opt.latent_dim``.
      * ``l1`` -- linear projection from ``opt.latent_dim`` to
        ``128 * init_size ** 2`` features.
      * ``conv_blocks`` -- two 2x nearest-neighbour upsampling stages with
        3x3 same-padding convolutions, ending in a conv to ``opt.channels``
        channels and ``Tanh`` (so outputs lie in [-1, 1]).

    With ``init_size = opt.img_size // 4``, the two upsamplings produce a
    spatial size of ``4 * init_size`` (equal to ``opt.img_size`` when it is
    divisible by 4). The forward pass is not part of this block; presumably it
    reshapes ``l1``'s output to (N, 128, init_size, init_size) before
    ``conv_blocks`` -- TODO confirm against the forward method.
    """

    def __init__(self):
        super().__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): BatchNorm2d's second positional parameter is ``eps``,
        # not ``momentum``. The original passed 0.8 positionally, i.e. as eps;
        # that is kept (now spelled explicitly) so training behaviour is
        # unchanged, but 0.8 was probably meant as a Keras-style momentum.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

..

# Module-level generator instance; sample_image() calls it as generator(z, labels).
generator = Generator()

..

def sample_image(n_row, batches_done, save_individual=True):
    """Save generated samples as one grid and, optionally, as separate files.

    Draws ``n_row ** 2`` noise vectors and conditions them on the labels
    ``0, 1, ..., n_row - 1`` repeated ``n_row`` times, so sample ``k`` has
    label ``k % n_row`` and every grid row shows one sample per label.

    Files written (the ``images/`` directory must already exist):
      * ``images/<batches_done>.png`` -- the full ``n_row x n_row`` grid.
      * ``images/<batches_done>_<k>.png`` -- sample ``k`` alone, for
        ``k`` in ``0 .. n_row ** 2 - 1`` (only when ``save_individual``).

    Args:
        n_row: Number of rows and columns in the grid. Labels go up to
            ``n_row - 1``, so presumably ``n_row <= opt.n_classes`` is
            required by the label embedding -- TODO confirm.
        batches_done: Integer used as the file-name prefix.
        save_individual: Also write each generated image to its own file.
    """
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))

    # Labels 0..n_row-1, repeated once per grid row
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)

    # Whole batch as a grid. The original sliced with an undefined ``i`` and
    # used nrow=1 (a single column); nrow=n_row lays out n_row images per row.
    save_image(
        gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True
    )

    if save_individual:
        # Indexing the batch yields one image, presumably (C, H, W). Note that
        # normalize=True min-max scales each file on its own, whereas the grid
        # above is scaled over the whole batch, so contrast may differ slightly.
        for k, img in enumerate(gen_imgs.data):
            save_image(img, "images/%d_%d.png" % (batches_done, k), normalize=True)

0 个答案:

没有答案