我正在分析在图像生成过程中使用DCGAN +爬行动物的元学习class。
我对此代码有两个问题。
第一个问题:为什么在DCGAN培训期间(第74行)
training_batch = torch.cat ([real_batch, fake_batch])
training_batch是否由真实示例(real_batch)和伪示例(fake_batch)组成?为什么要通过混合真实图像和虚假图像来进行训练?我见过许多DCGAN,但从未接受过这种方式的培训。
第二个问题:为什么在训练过程中使用了normalize_data函数(第49行)和unnormalize_data函数(第55行)?
def normalize_data(data):
data *= 2
data -= 1
return data
def unnormalize_data(data):
data += 1
data /= 2
return data
该项目使用Mnist数据集,如果我想使用CIFAR10这样的颜色数据集,是否需要修改这些归一化?
答案 0 :(得分:1)
训练GAN涉及给歧视者提供真实和虚假的例子。通常,您会看到它们分别在两个不同的场合给出。默认情况下,torch.cat
在第一个维度(dim=0
)上串联张量,第一个维度是批处理维度。因此,它只是使批处理大小增加了一倍,其中前半部分是真实图像,后半部分是伪图像。
为了计算损失,他们调整了目标,以便将前半部分(原始批次大小)分类为真实,而后半部分分类为伪造。来自initialize_gan
:
self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)
图像的浮点值介于[0,1]之间。归一化会更改以产生[-1,1]之间的值。 GAN通常在生成器中使用tanh,因此伪图像的值在[-1,1]之间,因此真实图像应在相同范围内,否则鉴别者将伪图像与真实图像区分开会很简单。
如果要显示这些图像,则需要先对其进行归一化,即将其转换为[0,1]之间的值。
该项目使用Mnist数据集,如果我想使用CIFAR10这样的颜色数据集,是否需要修改这些归一化?
不,您不需要更改它们,因为彩色图像的值也介于[0,1]之间,因此存在更多的值,代表3个通道(RGB)。
答案 1 :(得分:0)
如果您仔细阅读了文档(请查看def initialize_gan(self):
函数),将会发现
self.meta_g == Generator
self.meta_d == Discriminator
在您引用的行中,fake_batch被定义为Generator的一部分:
fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device))
training_batch = torch.cat([real_batch, fake_batch])
因此,因为它是GAN,所以您要给鉴别器提供假的和真实的图像,鉴别器必须弄清楚它是哪一个。
关于第二个问题,我想,但我不能完全确定这两个函数是否用于生成伪图像?我会仔细检查。
有帮助吗?