在Estimator SessionRunHook中更改tf.Variable值

时间:2018-08-10 10:50:31

标签: python tensorflow

我有一个tf.Estimator,其model_fn包含一个初始化为1.0的tf.Variable。我想根据开发集的准确性在每个时期更改变量值。我实现了SessionRunHook来实现此目的,但是当我尝试更改该值时,出现以下错误:

raise RuntimeError("Graph is finalized and cannot be modified.")

这是挂钩的代码:

    class DynamicWeightingHook(tf.train.SessionRunHook):
        def __init__(self, epoch_size, gamma_value):
            self.gamma = gamma_value
            self.epoch_size = epoch_size
            self.steps = 0

        def before_run(self, run_context):
            self.steps += 1

        def after_run(self, run_context, run_values):
            if self.steps % epoch_size == 0:  # epoch 
                with tf.variable_scope("lambda_scope", reuse=True):
                    lambda_tensor = tf.get_variable("lambda_value")
                    tf.assign(lambda_tensor, self.gamma_value)
                    self.gamma_value += 0.1

我了解在运行该挂钩时该图已完成,但是我想知道在训练过程中是否还有其他方法可以使用Estimator API更改model_fn图中的变量值。

1 个答案:

答案 0 :(得分:2)

现在设置钩子的方式实际上是在每次会话运行后尝试创建新的变量/操作。相反,您应该预先定义tf.assign op并将其传递给钩子,以便它可以在必要时运行op本身,或者在钩子的__init__中定义assign op。您可以通过after_run参数访问run_context内部的会话。像

class DynamicWeightingHook(tf.train.SessionRunHook):
    def __init__(self, epoch_size, gamma_value, lambda_tensor):
        self.gamma = gamma_value
        self.epoch_size = epoch_size
        self.steps = 0
        self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)

    def before_run(self, run_context):
        self.steps += 1

    def after_run(self, run_context, run_values):
        if self.steps % epoch_size == 0:  # epoch 
            run_context.session.run(self.update_op)
            self.gamma += 0.1

这里有一些警告。首先,我不确定您是否可以使用这样的Python整数来执行tf.assign,即一旦更改gamma后它是否可以正确更新。如果这不起作用,您可以尝试以下方法:

class DynamicWeightingHook(tf.train.SessionRunHook):
    def __init__(self, epoch_size, gamma_value, lambda_tensor):
        self.gamma = gamma_value
        self.epoch_size = epoch_size
        self.steps = 0
        self.gamma_placeholder = tf.placeholder(tf.float32, [])
        self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)

    def before_run(self, run_context):
        self.steps += 1

    def after_run(self, run_context, run_values):
        if self.steps % epoch_size == 0:  # epoch 
            run_context.session.run(self.update_op, feed_dict={self.gamma_placeholder: self.gamma})
            self.gamma += 0.1

在这里,我们使用了一个额外的占位符,以便能够始终将“当前”伽玛传递给assign op。

第二,由于挂钩需要访问变量,因此您需要在模型函数内定义挂钩。您可以在EstimatorSpec(请参阅here)中将这样的钩子传递给训练过程。