train
的{{1}}函数具有以下签名:
tf.estimator.Estimator
我正在训练一个网络,我需要根据相当复杂的算法的结果每隔几步手动设置一些变量,该算法无法在图中实现。是否可以在挂钩中设置变量的值?有谁知道任何示例代码吗?
为了不浪费资源,我不需要在每个培训步骤中都使用钩子。有没有一种方法可以指定我的挂钩仅每N步调用一次?当然,我可以自己做一个计数器,等到我的算法不运行时返回,但这似乎应该是可配置的。
答案 0 :(得分:1)
是应该的!我不确切知道此变量在哪个范围内或如何引用它,因此我只假设您知道它的名称。我基本上是从其他答案here中窃取代码。
只需在训练循环之前创建一个钩子即可:
class VariableUpdaterHook(tf.train.SessionRunHook):
def __init__(self, frequency, variable_name):
# variable name should be like: parent/scope/some/path/variable_name:0
self._global_step_tensor = None
self.variable = None
self.frequency = frequency
self.variable_name = variable_name
def after_create_session(self, session, coord):
self.variable = session.graph.get_tensor_by_name(self.variable_name)
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
def after_run(self, run_context, run_values):
global_step = run_context.session.run(self._global_step_tensor)
if global_step % self.frequency == 0:
new_variable_value = complicated_algorithm(...)
assign_op = self.variable.assign(new_variable_value)
run_context.session.run(assign_op)
我认为不值得花很多精力研究另一种避免每次迭代后调用的方法,因为它们非常便宜。所以要按照您的建议去做。
注意:由于我目前没有用例,因此我没有时间对其进行调试。但我希望你能明白。