我有一个带有钩子的训练工作,可以在训练过程中设置变量:
class MyHook(tf.train.SessionRunHook):
...
def begin(self):
self._assign_op = tf.assign(...)
self._assign_placeholder = ...
def after_run(self, run_context, run_values):
...
run_context.session.run(self._assign_op, feed_dict={self._assign_placeholder: ...})
培训是像往常一样进行的:
my_model.train(..., hooks=[MyHook()])
现在,我试图了解如果我转到分布式培训中会发生什么。挂钩可以在所有机器上运行吗?如果当前我使用session.graph.get_tensor_by_name()
在挂钩中找到要分配给张量的张量,那么这项工作还是在不同的副本上张量具有不同的名称?
换句话说,在多GPU或多节点训练的情况下,要使类似的东西工作,我需要考虑什么?