for epoch in range(num_epoch):
for i, (img, _) in enumerate(dataloader):
num_img = img.size(0)
# =================train discriminator
img = img.view(num_img, -1)
real_img = Variable(img).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()
我不明白训练代码中的torch.ones和torch.zeros是什么。
任何人都可以解释一下吗?
答案 0 :(得分:2)
正如你可能知道的那样:在GAN中,生成器试图通过说服假例子是一个真实的例子来欺骗判别者。鉴别者经过培训,可以区分真实的例子和假的例子。另一方面,生成器被训练以生成(假的)示例,这些示例看起来非常接近真实示例。
分析您共享的代码/示例(在链接中)。
生成器:是一个简单的前馈神经网络。生成器从随机(噪声)分布生成28 * 28
个图像。生成器的目标是生成看起来像真实图像的图像。
判别器:是一个简单的前馈神经网络。鉴别器在给定图像的情况下提供sigmoid([0,1])分数。鉴别器的目标是对假图像给出低分(~0)并对真实图像给出高分(~1)。从本质上讲,鉴别者想要区分真实图像和假图像。
代码如何运作?
首先,提供鉴别器的实际图像的示例,并且基于鉴别器的预测得分计算损失。
# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out # closer to 1 means better
然后,鉴别器被提供由生成器生成的伪图像。根据鉴别者对假例子的得分计算损失。
# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out # closer to 0 means better
基本上,发电机和鉴别器相互竞争,成为实现目标的专家。我们可以用这种方式思考:如果我们有一个完美的生成器,那么它将创建完全看起来像真实的假例子,鉴别器将无法区分它们,反之亦然。
您在上面提供的代码只是使用torch.zeros()
和torch.ones()
创建标签。您可以简单地将其视为真实和虚假图像的二进制标签。