在学习GAN时,我注意到代码示例展示了这种模式:
鉴别器是这样训练的:
d_optim.zero_grad()
real_pred = d(real_batch)
d_loss = d_loss_fn(real_pred, torch.ones(real_batch_size, 1))
d_loss.backward()
fake_pred = d(g(noise_batch).detach())
d_loss = d_loss_fn(fake_pred, torch.zeros(noise_batch_size, 1))
d_loss.backward()
d_optim.step()
生成器是这样训练的:
g_optim.zero_grad()
fake_pred = d(g(noise_batch))
g_loss = g_loss_fn(fake_pred, torch.ones(noise_batch_size, 1))
g_loss.backward()
g_optim.step()
有人提到,d(g(noise_batch).detach())
是为区分符而不是d(g(noise_batch))
写的,以防止d_optim.step()
训练g
,但对于d(g(noise_batch))
发电机g_optim.step()
也会训练d
吗?
实际上,为什么我们要d(g(noise_batch).detach())
,例如d_optim = torch.optim.SGD(d.parameters(), lr=0.001)
?这不是指定d.parameters()
还是g.parameters()
是否要更新?
答案 0 :(得分:2)
TLDR:optimizer
将仅更新为其指定的参数,而backward()
调用将计算计算图中所有变量的梯度。因此,detach()
那时不需要进行梯度计算的变量很有用。
我相信答案在于在PyTorch中实现事物的方式。
tensor.detach()
创建一个张量,该张量与不需要渐变的tensor
共享存储。因此,有效地切断了计算图。也就是说,执行fake_pred = d(g(noise_batch).detach())
将分离(切断)生成器的计算图。 backward()
时,将为整个计算图计算梯度(与优化器是否使用梯度无关)。因此,切断发电机部件将避免对发电机权重进行梯度计算(因为它们不是必需的)。optimizer
时仅更新传递给特定optimizer.step()
的参数。因此,g_optim
只会优化传递给它的参数(您没有明确提及传递给g_optim
的参数)。同样,d_optim
仅会更新d.parameters()
,因为您已明确指定。