如何在Tensorflow中异步更新GAN Generator和Discriminator?

时间:2017-12-25 22:11:01

标签: tensorflow generative adversarial-machines

我想用Tensorflow开发一个GAN,其中Generator是一个自动编码器,而Discriminator是一个带二进制输出的卷积神经网络。开发自动编码器和CNN没有问题,但我的想法是为每个组件(判别器和发生器)训练1个纪元,并重复这个循环1000个纪元,保持前一个训练纪元的结果(权重)为下一个。我该如何操作呢?

2 个答案:

答案 0 :(得分:1)

如果您有两个名为train_step_generatortrain_step_discriminator的操作(例如,每个操作都是tf.train.AdamOptimizer().minimize(loss)形式,每个操作都有适当的损失),那么您的训练循环应该是类似于以下结构:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1000):
        if epoch%2 == 0: # train discriminator on even epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
                batch = get_next_batch(batch_size)
                sess.run(train_step_discriminator,feed_dict={z:z_, x:batch})
        else: # train generator on odd epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size)  # this is the input to the generator
                sess.run(train_step_generator,feed_dict={z:z_})

权重将在迭代之间保持不变。

答案 1 :(得分:0)

我解决了这个问题。实际上,我希望自动编码器的输出是CNN的输入,连接GAN并以1:1的比例更新权重。我注意到我必须特别注意区分发生器和鉴别器的损耗,否则在第二次循环开始时,发生器的张量损失将被浮点数替换,即Discriminator产生的最后一次丢失。

这是代码:

with tf.Session() as sess:
sess.run(init)
for i in range(1, num_steps+1):

这里是发电机培训

    batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)        
    _, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n,784),
                    Y:batch_y})
    if i % display_step == 0 or i == 1:
        print('Epoch %i: Denoising Loss: %f' % (i, l))

此处,Generator的输出将用作Discriminator的输入

    output=sess.run([decoder_op],feed_dict={X: x_train})
    x_train2=np.array(output).reshape(n,784).astype(np.float64)

这里是判别者培训

    batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
    sess.run(train_op, feed_dict={X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8})
    if i % display_step == 0 or i == 1:
        loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2,
                                                             Y2: batch_y2,
                                                             keep_prob: 1.0})
        print("Epoch " + str(i) + ", CNN Loss= " + \
              "{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc))

这样,异步更新可以按比例1:1,1:5,5:1(判别器:生成器)或任何其他方式操作