我正在尝试在TPU上训练GAN,所以我一直在搞混TPUEstimator类和随附的模型函数来尝试实现WGAN训练循环。我正在尝试使用tf.cond
来合并TPUEstimatorSpec的两个训练操作:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
和critic_opt
是我正在使用的优化程序的最小化功能,也设置为更新全局步骤。 CRITIC_UPDATES_PER_GEN_UPDATE
就是一个python常量,它是WGAN培训的一部分。我尝试使用tf.cond
找到GAN模型,但是所有模型都使用tf.group
,我不能使用它,因为您需要优化注释器的次数比生成器多得多。
但是,每运行100批,全局步数就会根据检查点数量增加200。我的模型仍在正确训练吗?还是tf.cond
不能用于训练GAN?
答案 0 :(得分:0)
tf.cond
不应以这种方式用于训练GAN。
您之所以得到200,是因为每个训练步骤都评估了{strong> true_fn
和false_fn
的副作用(如赋值操作)。副作用之一是两个优化程序都定义了全局步骤tf.assign_add
。
因此,发生的事情就像
global_step++ (gen_opt)
和global_step++ (critic_op)
的表现true_fn
正文或false_fn
正文(取决于情况)。如果您想使用tf.cond
来训练GAN,则必须从true_fn
/ {{的外部)删除所有辅助操作(例如赋值,因此是优化步骤的定义)。 1}},并声明其中的所有内容。
作为参考,您可以看到有关false_fn
:https://stackoverflow.com/a/37064128/2891324