有关DCGAN的两个问题:数据规范化和假/真实批处理

时间:2020-05-21 16:43:15

标签: deep-learning pytorch generative-adversarial-network dcgan

我正在分析在图像生成过程中使用DCGAN +爬行动物的元学习class

我对此代码有两个问题。

第一个问题:为什么在DCGAN培训期间(第74行)

training_batch = torch.cat ([real_batch, fake_batch])

training_batch是否由真实示例(real_batch)和伪示例(fake_batch)组成?为什么要通过混合真实图像和虚假图像来进行训练?我见过许多DCGAN,但从未接受过这种方式的培训。

第二个问题:为什么在训练过程中使用了normalize_data函数(第49行)和unnormalize_data函数(第55行)?

def normalize_data(data):
    data *= 2
    data -= 1
    return data


def unnormalize_data(data):
    data += 1
    data /= 2
    return data

该项目使用Mnist数据集,如果我想使用CIFAR10这样的颜色数据集,是否需要修改这些归一化?

2 个答案:

答案 0 :(得分:1)

训练GAN涉及给歧视者提供真实和虚假的例子。通常,您会看到它们分别在两个不同的场合给出。默认情况下,torch.cat在第一个维度(dim=0)上串联张量,第一个维度是批处理维度。因此,它只是使批处理大小增加了一倍,其中前半部分是真实图像,后半部分是伪图像。

为了计算损失,他们调整了目标,以便将前半部分(原始批次大小)分类为真实,而后半部分分类为伪造。来自initialize_gan

self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)

图像的浮点值介于[0,1]之间。归一化会更改以产生[-1,1]之间的值。 GAN通常在生成器中使用tanh,因此伪图像的值在[-1,1]之间,因此真实图像应在相同范围内,否则鉴别者将伪图像与真实图像区分开会很简单。

如果要显示这些图像,则需要先对其进行归一化,即将其转换为[0,1]之间的值。

该项目使用Mnist数据集,如果我想使用CIFAR10这样的颜色数据集,是否需要修改这些归一化?

不,您不需要更改它们,因为彩色图像的值也介于[0,1]之间,因此存在更多的值,代表3个通道(RGB)。

答案 1 :(得分:0)

如果您仔细阅读了文档(请查看def initialize_gan(self):函数),将会发现

self.meta_g == Generator
self.meta_d == Discriminator

在您引用的行中,fake_batch被定义为Generator的一部分:

fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device))
training_batch = torch.cat([real_batch, fake_batch])

因此,因为它是GAN,所以您要给鉴别器提供假的和真实的图像,鉴别器必须弄清楚它是哪一个。

关于第二个问题,我想,但我不能完全确定这两个函数是否用于生成伪图像?我会仔细检查。

有帮助吗?