TensorFlow的钩子和分布式培训

时间:2018-10-22 12:32:02

标签: python tensorflow tensorflow-estimator

我有一个带有钩子的训练工作,可以在训练过程中设置变量:

class MyHook(tf.train.SessionRunHook):
   ...
   def begin(self):
       self._assign_op = tf.assign(...)
       self._assign_placeholder = ...

   def after_run(self, run_context, run_values):
       ...
       run_context.session.run(self._assign_op, feed_dict={self._assign_placeholder: ...})

培训是像往常一样进行的:

my_model.train(..., hooks=[MyHook()])

现在,我试图了解如果我转到分布式培训中会发生什么。挂钩可以在所有机器上运行吗?如果当前我使用session.graph.get_tensor_by_name()在挂钩中找到要分配给张量的张量,那么这项工作还是在不同的副本上张量具有不同的名称?

换句话说,在多GPU或多节点训练的情况下,要使类似的东西工作,我需要考虑什么?

0 个答案:

没有答案