SyncReplicasOptimizer + MonitoredTrainingSession通过一个会话多次运行网络。运行

时间:2018-07-18 01:39:02

标签: python tensorflow

描述问题

tf.train.MonitoredTrainingSessiontf.train.SyncReplicasOptimizer组合而成的

训练模型, 我发现模型网络将仅通过一次session.run调用而运行两次,并且会提高许多NaN值。更多,这将更新批处理规范化参数。 我通过tf.train.SyncReplicasOptimizer.make_session_run_hook进行初始化。顺便说一句,没有tf.train.SyncReplicasOptimizer的代码将是正确的。

源代码/日志

# 1. model define
input = tf.placeholder_with_default()
phase = tf.placeholder_with_default(True, shape=(), name='phase')
... network define ...
self.loss = xxx

# 2. optimizer define
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.optimizer = tf.train.SyncReplicasOptimizer(self.optimizer, replicas_to_aggregate=worker_num, total_num_replicas=worker_num)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
     self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)

# 3. hook define
sync_replicas_hook = self.optimizer.make_session_run_hook((flags.task_index == 0), num_tokens=0)

# 4. train
with tf.train.MonitoredTrainingSession(master=server.target, hooks=[sync_replicas_hook], is_chief=is_chief):
    outpus = sess.run([self.train_op, global_step], feed_dict={'phase:0':True, ...})

问题A:

我在网络中定义时添加了tf.Print行,我发现这行执行一次sess.run调用两次。

问题B:

phase = True时,BatchNorm参数正确更新,但是当phase = Falsemoving_meanmoving_variance变为NaN时,丢失也是如此。

使用异步优化器进行培训时,一切都很好。

问题C:

我不确定是否有人尝试过SyncReplicasOptimizer + MonitoredTrainingSession +模型输入的占位符。创建会话后,在挂钩函数中初始化变量期间,它会报告需要将值提供给占位符。我很困惑在session.run显式调用之前触发网络计算的原因。最后,我更改为使用具有默认值的placeholder_with_default来解决此问题。但是我认为这不是一个好主意。 在文档中找不到很少的信息。

0 个答案:

没有答案
相关问题