我有一个tf.Estimator,其model_fn
包含一个初始化为1.0的tf.Variable。我想根据开发集的准确性在每个时期更改变量值。我实现了SessionRunHook
来实现此目的,但是当我尝试更改该值时,出现以下错误:
raise RuntimeError("Graph is finalized and cannot be modified.")
这是挂钩的代码:
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
with tf.variable_scope("lambda_scope", reuse=True):
lambda_tensor = tf.get_variable("lambda_value")
tf.assign(lambda_tensor, self.gamma_value)
self.gamma_value += 0.1
我了解在运行该挂钩时该图已完成,但是我想知道在训练过程中是否还有其他方法可以使用Estimator API更改model_fn图中的变量值。
答案 0 :(得分:2)
现在设置钩子的方式实际上是在每次会话运行后尝试创建新的变量/操作。相反,您应该预先定义tf.assign
op并将其传递给钩子,以便它可以在必要时运行op本身,或者在钩子的__init__
中定义assign op。您可以通过after_run
参数访问run_context
内部的会话。像
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op)
self.gamma += 0.1
这里有一些警告。首先,我不确定您是否可以使用这样的Python整数来执行tf.assign
,即一旦更改gamma
后它是否可以正确更新。如果这不起作用,您可以尝试以下方法:
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.gamma_placeholder = tf.placeholder(tf.float32, [])
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op, feed_dict={self.gamma_placeholder: self.gamma})
self.gamma += 0.1
在这里,我们使用了一个额外的占位符,以便能够始终将“当前”伽玛传递给assign op。
第二,由于挂钩需要访问变量,因此您需要在模型函数内定义挂钩。您可以在EstimatorSpec
(请参阅here)中将这样的钩子传递给训练过程。