我正在使用tf.estimator
API来训练模型。
据我了解,model_fn
定义了计算图,根据tf.estimator.EstimatorSpec
返回不同的mode
。
在mode==tf.estimator.ModeKeys.TRAIN
中,可以指定在每次训练迭代时调用train_op
,然后更改trainable
tf.Variable
个实例,以优化某种损失。
让我们拨打train_op optimizer
,以及变量A
和B
。
为了加快预测和评估,我希望有一个辅助的不可训练 tf.Variable
Tensor C
,完全依赖关于已经训练过的变量。因此,这个张量的值是可输出的。此Tensor不会影响训练损失。让我们假设我们想要:
C = tf.Variable(tf.matmul(A,B))
update_op = tf.assign(C, tf.matmul(A,B))
将tf.group(optimizer, update_op)
train_op
作为EstimatorSpec
传递给train_op
效果很好但是减慢了很多培训,因为C
现在更新了{{1}在每次迭代时。
由于C
仅在评估/预测时间需要,因此在培训结束时拨打update_op
就足够了。
是否可以在培训结束时指定一个变量tf.estimator.Estimator
?
答案 0 :(得分:2)
一般来说,模型函数的单次迭代不知道训练是否会在运行后结束,所以我怀疑这可以直接完成。我看到两个选择:
如果您仅在培训后需要辅助变量,则可以使用tf.estimator.Estimator.get_variable_value
(请参阅here)提取变量A
和B
之后的值训练为numpy数组并进行计算以获得C
。然而,C
不会成为模型的一部分。
使用钩子(见here)。您可以使用end
方法编写一个钩子,该方法将在会话结束时调用(即训练停止时)。你可能需要研究如何定义/使用钩子 - 例如here你可以找到大多数" basic"的实现。钩子已经在Tensorflow中了。粗糙的骨架看起来像这样:
class UpdateHook(SessionRunHook):
def __init__(update_variable, other_variables):
self.update_op = tf.assign(update_variable, some_fn(other_variables))
def end(session):
session.run(self.update_op)
由于钩子需要访问变量,因此需要在模型函数中定义钩子。您可以将此类挂钩传递到EstimatorSpec
中的培训流程(请参阅here)。
我还没有测试过这个!我不确定你是否可以在钩子中定义操作。如果没有,它应该有效地在模型函数中定义更新操作并直接将其传递给钩子。
答案 1 :(得分:0)
使用挂钩是一种解决方案。 但是请注意,如果要更改变量的值,则不得在end()函数中进行更改,因为更改的结果无法存储在检查点文件中。例如,如果在after_run函数中更改该值,则结果将存储在检查点中。