在`tf.estimator`中,如何在训练结束时(而不是在每次迭代时)对变量进行“tf.assign”?

时间:2018-05-18 07:27:12

标签: python tensorflow machine-learning deep-learning

我正在使用tf.estimator API来训练模型。

据我了解,model_fn定义了计算图,根据tf.estimator.EstimatorSpec返回不同的mode

mode==tf.estimator.ModeKeys.TRAIN中,可以指定在每次训练迭代时调用train_op,然后更改trainable tf.Variable个实例,以优化某种损失。

让我们拨打train_op optimizer,以及变量AB

为了加快预测和评估,我希望有一个辅助的不可训练 tf.Variable Tensor C ,完全依赖关于已经训练过的变量。因此,这个张量的值是可输出的。此Tensor不会影响训练损失。让我们假设我们想要:

C = tf.Variable(tf.matmul(A,B))
update_op = tf.assign(C, tf.matmul(A,B))
  • 我尝试了什么:

tf.group(optimizer, update_op) train_op作为EstimatorSpec传递给train_op 效果很好但是减慢了很多培训,因为C现在更新了{{1}在每次迭代时。

由于C仅在评估/预测时间需要,因此在培训结束时拨打update_op就足够了。

是否可以在培训结束时指定一个变量tf.estimator.Estimator

2 个答案:

答案 0 :(得分:2)

一般来说,模型函数的单次迭代不知道训练是否会在运行后结束,所以我怀疑这可以直接完成。我看到两个选择:

  1. 如果您仅在培训后需要辅助变量,则可以使用tf.estimator.Estimator.get_variable_value(请参阅here)提取变量AB之后的值训练为numpy数组并进行计算以获得C。然而,C不会成为模型的一部分。

  2. 使用钩子(见here)。您可以使用end方法编写一个钩子,该方法将在会话结束时调用(即训练停止时)。你可能需要研究如何定义/使用钩子 - 例如here你可以找到大多数" basic"的实现。钩子已经在Tensorflow中了。粗糙的骨架看起来像这样:

    class UpdateHook(SessionRunHook):
        def __init__(update_variable, other_variables):
            self.update_op = tf.assign(update_variable, some_fn(other_variables))
    
        def end(session):
            session.run(self.update_op)
    

    由于钩子需要访问变量,因此需要在模型函数中定义钩子。您可以将此类挂钩传递到EstimatorSpec中的培训流程(请参阅here)。

    我还没有测试过这个!我不确定你是否可以在钩子中定义操作。如果没有,它应该有效地在模型函数中定义更新操作并直接将其传递给钩子。

答案 1 :(得分:0)

使用挂钩是一种解决方案。 但是请注意,如果要更改变量的值,则不得在end()函数中进行更改,因为更改的结果无法存储在检查点文件中。例如,如果在after_run函数中更改该值,则结果将存储在检查点中。