我正在运行基于DCGAN的GAN,并在尝试使用WGAN,但是对于如何训练WGAN感到有些困惑。
在官方Wasserstein GAN PyTorch implementation中,每一次生成器训练都将歧视者/批评者训练过Diters
(通常5次)。
这是否意味着评论者/区分者在Diters
个批次或整个数据集 Diters
上进行训练?如果我没记错的话,官方的实施建议对鉴别器/批评者进行整个数据集 Diters
的培训,但是WGAN的其他实施(在PyTorch和TensorFlow等中)相反。
哪个是正确的? The WGAN paper(至少对我而言)表示Diters
个批次。整个数据集的训练显然要慢几个数量级。
谢谢!
答案 0 :(得分:1)
正确的做法是将迭代视为批处理。
在原始paper中,对于注释者/鉴别者的每次迭代,他们正在采样一批大小为m
的真实数据和一批大小为m
的先前样本{{1} }来工作。在评论者经过p(z)
次迭代训练之后,他们训练了生成器,该生成器也通过对Diters
的一批先前样本进行采样而开始。
因此,每个迭代都在批量处理。
在official implementation中,这也正在发生。可能令人困惑的是,它们使用变量名p(z)
表示训练模型的时期数。尽管他们使用另一种方案在162-166行设置niter
:
Diters
如本文中所述,他们正在对# train the discriminator Diters times
if gen_iterations < 25 or gen_iterations % 500 == 0:
Diters = 100
else:
Diters = opt.Diters
个批次的评论家进行培训。
答案 1 :(得分:0)
此WGAN的实现将其显示为生成器每次运行时鉴别器的Diter批次-https://github.com/shayneobrien/generative-models/blob/74fbe414f81eaed29274e273f1fb6128abdb0ff5/src/w_gan.py#L88