在estimator模型函数中使用tf.cond()在TPU上训练WGAN会导致加倍的global_step

时间:2019-01-27 05:48:54

标签: tensorflow generative-adversarial-network tpu

我正在尝试在TPU上训练GAN,所以我一直在搞混TPUEstimator类和随附的模型函数来尝试实现WGAN训练循环。我正在尝试使用tf.cond来合并TPUEstimatorSpec的两个训练操作:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_optcritic_opt是我正在使用的优化程序的最小化功能,也设置为更新全局步骤。 CRITIC_UPDATES_PER_GEN_UPDATE就是一个python常量,它是WGAN培训的一部分。我尝试使用tf.cond找到GAN模型,但是所有模型都使用tf.group,我不能使用它,因为您需要优化注释器的次数比生成器多得多。  但是,每运行100批,全局步数就会根据检查点数量增加200。我的模型仍在正确训练吗?还是tf.cond不能用于训练GAN?

1 个答案:

答案 0 :(得分:0)

tf.cond不应以这种方式用于训练GAN。

您之所以得到200,是因为每个训练步骤都评估了{strong> true_fnfalse_fn的副作用(如赋值操作)。副作用之一是两个优化程序都定义了全局步骤tf.assign_add

因此,发生的事情就像

  • global_step++ (gen_opt)global_step++ (critic_op)的表现
  • 条件评估
  • 执行true_fn正文或false_fn正文(取决于情况)。

如果您想使用tf.cond来训练GAN,则必须从true_fn / {{的外部)删除所有辅助操作(例如赋值,因此是优化步骤的定义)。 1}},并声明其中的所有内容。

作为参考,您可以看到有关false_fnhttps://stackoverflow.com/a/37064128/2891324

的行为的答案。