我需要一些关于GAN代码的解释

时间:2018-03-16 02:42:16

标签: pytorch

Here is github codes

for epoch in range(num_epoch):
for i, (img, _) in enumerate(dataloader):
    num_img = img.size(0)
    # =================train discriminator
    img = img.view(num_img, -1)
    real_img = Variable(img).cuda()
    real_label = Variable(torch.ones(num_img)).cuda()
    fake_label = Variable(torch.zeros(num_img)).cuda()

我不明白训练代码中的torch.ones和torch.zeros是什么。

任何人都可以解释一下吗?

1 个答案:

答案 0 :(得分:2)

正如你可能知道的那样:在GAN中,生成器试图通过说服假例子是一个真实的例子来欺骗判别者。鉴别者经过培训,可以区分真实的例子和假的例子。另一方面,生成器被训练以生成(假的)示例,这些示例看起来非常接近真实示例。

分析您共享的代码/示例(在链接中)。

生成器:是一个简单的前馈神经网络。生成器从随机(噪声)分布生成28 * 28个图像。生成器的目标是生成看起来像真实图像的图像。

判别器:是一个简单的前馈神经网络。鉴别器在给定图像的情况下提供sigmoid([0,1])分数。鉴别器的目标是对假图像给出低分(~0)并对真实图像给出高分(~1)。从本质上讲,鉴别者想要区分真实图像和假图像。

代码如何运作?

首先,提供鉴别器的实际图像的示例,并且基于鉴别器的预测得分计算损失。

# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out  # closer to 1 means better

然后,鉴别器被提供由生成器生成的伪图像。根据鉴别者对假例子的得分计算损失。

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out  # closer to 0 means better

基本上,发电机和鉴别器相互竞争,成为实现目标的专家。我们可以用这种方式思考:如果我们有一个完美的生成器,那么它将创建完全看起来像真实的假例子,鉴别器将无法区分它们,反之亦然。

您在上面提供的代码只是使用torch.zeros()torch.ones()创建标签。您可以简单地将其视为真实和虚假图像的二进制标签。