生成器的优化器是否也训练鉴别器?

时间:2020-03-20 22:25:23

标签: pytorch

在学习GAN时,我注意到代码示例展示了这种模式:

鉴别器是这样训练的:

d_optim.zero_grad()

real_pred = d(real_batch)
d_loss = d_loss_fn(real_pred, torch.ones(real_batch_size, 1))
d_loss.backward()

fake_pred = d(g(noise_batch).detach())
d_loss = d_loss_fn(fake_pred, torch.zeros(noise_batch_size, 1))
d_loss.backward()

d_optim.step()

生成器是这样训练的:

g_optim.zero_grad()

fake_pred = d(g(noise_batch))
g_loss = g_loss_fn(fake_pred, torch.ones(noise_batch_size, 1))
g_loss.backward()

g_optim.step()

有人提到,d(g(noise_batch).detach())是为区分符而不是d(g(noise_batch))写的,以防止d_optim.step()训练g,但对于d(g(noise_batch))发电机g_optim.step()也会训练d吗?

实际上,为什么我们要d(g(noise_batch).detach()),例如d_optim = torch.optim.SGD(d.parameters(), lr=0.001)?这不是指定d.parameters()还是g.parameters()是否要更新?

1 个答案:

答案 0 :(得分:2)

TLDR:optimizer将仅更新为其指定的参数,而backward()调用将计算计算图中所有变量的梯度。因此,detach()那时不需要进行梯度计算的变量很有用。

我相信答案在于在PyTorch中实现事物的方式。

  • tensor.detach()创建一个张量,该张量与不需要渐变的tensor共享存储。因此,有效地切断了计算图。也就是说,执行fake_pred = d(g(noise_batch).detach())将分离(切断)生成器的计算图。
  • 在损耗上调用backward()时,将为整个计算图计算梯度(与优化器是否使用梯度无关)。因此,切断发电机部件将避免对发电机权重进行梯度计算(因为它们不是必需的)。
  • 此外,在调用optimizer时仅更新传递给特定optimizer.step()的参数。因此,g_optim只会优化传递给它的参数(您没有明确提及传递给g_optim的参数)。同样,d_optim仅会更新d.parameters(),因为您已明确指定。