我正在尝试使用tf.contrib.gan.estimator.GANEstimator(https://www.tensorflow.org/api_docs/python/tf/contrib/gan/estimator/GANEstimator)训练Wasserstein生成对抗网络。但是,我不是使用生成器网络将噪声映射到图像,而是使用了将“不良图像”映射为“良好图像”的优化程序。因此输入和输出尺寸是相同的。
我只想对批评者网络进行交叉检查训练,以查看梯度惩罚损失和批评者损失是否正在收敛。生成器网络应该只传递图像,但我不知道如何为GANEstimator提供生成器功能,从而做到这一点。
当刚返回输入时,我得到一个AssertionError: assert variables_to_train
。
将输入乘以1时 发生了一些非常奇怪的事情,例如梯度惩罚损失大部分为零,除了一些尖峰,并且损失正在发散:
def generator_network(self, images, mode):
noop = tf.get_variable("noop", shape=[], initializer=tf.initializers.ones())
return images * noop