如何从GAN训练发电机?

时间:2019-05-11 16:44:34

标签: tensorflow keras neural-network deep-learning generative-adversarial-network

阅读GAN教程和代码示例后,我仍然不了解如何训练generator。假设我们有一个简单的案例: -发生器输入为噪声,输出为灰度图像10x10 -鉴别器输入为10x10图像,输出为0到1(伪或真)的单个值

训练鉴别器很容易-将其输出真实,并期望为1。取得假冒输出,并期望为0。我们正在处理实际输出大小-单个值。

但是训练生成器是不同的-我们将假输出(1值)作为该值的预期输出。但这听起来更像是对描述符的培训。生成器的输出是图像10x10,我们如何仅用1个单个值对其进行训练?在这种情况下,反向传播可能如何工作?

2 个答案:

答案 0 :(得分:2)

要训练生成器,您必须在整个组合模型中向后传播,同时冻结鉴别器的权重,以便仅更新生成器。

为此,我们必须计算d(g(z; θg); θd),其中θg和θd是生成器和鉴别器的权重。要更新生成器,我们可以计算梯度wrt。仅将{g}设置为θg,然后使用正常梯度下降更新θg。

在Keras中,这可能看起来像这样(使用功能性API):

∂loss(d(g(z; θg); θd)) / ∂θg

通过将genInput = Input(input_shape) discriminator = ... generator = ... discriminator.trainable = True discriminator.compile(...) discriminator.trainable = False combined = Model(genInput, discriminator(generator(genInput))) combined.compile(...) 设置为False,不影响已编译的模型,仅冻结以后编译的模型。因此,该鉴别器可以作为独立模型进行训练,但可以冻结在组合模型中。

然后,训练您的GAN:

trainable

答案 1 :(得分:0)

我想理解生成器训练过程的最好方法是修改所有训练循环。

对于每个时期:

  1. 更新鉴别器:

    • 前向真实图像迷你批处理通过鉴别器;

    • 计算鉴别器损耗并计算后向传递的梯度;

    • 通过Generator生成小批量的伪图像;

    • 通过鉴别器向前生成伪造的小批量通行证;

    • 计算鉴别器损耗并为后向传递得出梯度;

    • 添加(真实的小批量渐变,伪造的小批量渐变)

    • 更新鉴别器(使用Adam或SGD)。

  2. 更新生成器:

    • 翻转目标:对于Generator,伪造的图像被标记为真实图像。注意:此步骤可确保对生成器使用交叉熵最小化。如果我们继续实施GAN minmax游戏,它将有助于克服Generator消失的梯度问题。

    • 转发伪造的图像小批量通过更新的鉴别器;

    • 根据更新的鉴别器输出来计算生成器损耗,例如:

    损失函数(伪造图像由鉴别器真实估计的概率为1)。
    注意:此处1代表假图片的真实标签。

    • 更新Generator(使用Adam或SGD)

我希望这会有所帮助。从训练过程中可以看出,GAN参与者在某种程度上是“合作的,因为鉴别器会估算数据与模型分布密度的比率,然后与生成器自由共享此信息。从这个角度来看,鉴别器是比起对手,更像是老师指示生成器如何进行改进”(引自I.Goodfellow tutorial)。