我可以从估计器获取张量流会话吗?

时间:2019-04-04 10:27:42

标签: tensorflow tensorflow-estimator

我正在使用tf.estimator的LinearRegressor,并将学习率衰减(本来是指数衰减)更改为使用损耗的衰减。但是要做到这一点,我需要将评估损失传递给学习率衰减张量的一些占位符,并且在此步骤中,我需要tf.session。

我尝试tf.get_default_session()来获取估算器进行的会话,但是此会话具有与估算器使用的图不同的图。


    def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
      # If loss is not reduced, than decay with decay_rate.

    loss = tf.placeholder(tf.float32)
    estimator = tf.estimator.LinearRegressor(
    feature_columns=feature_columns,
    optimizer==lambda: tf.train.FtrlOptimizer(
        learning_rate=my_decay(learning_rate=0.1,
        global_step=tf.get_global_step(), decay_step=10000,
        loss=loss, decay_rate=0.96)),
      config=sess_config
    )

    for _ in range(n_epoches):
      metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
      session.run(loss.assign(metrics['loss']))

使用上述代码,我需要从估算器中获取session。 有什么办法吗?

提前谢谢!

1 个答案:

答案 0 :(得分:0)

针对此类问题的预期解决方案是子类tf.train.SessionRunHook并重写before_run方法以返回合适的tf.train.SessionRunArgs。这将允许您在训练时输入值,并将提取内容添加到session.run调用中。您的班级将必须在调用之间携带对占位符和loss状态的引用。

然后,您只需实例化该类,然后将钩子添加到您的hooks调用中的estimator.train参数中,或者在本例中将其添加到train_spec中。如果您希望使用评估损失而不是训练损失,则可以通过向eval_spec添加另一个钩子来实现,以读取after_run方法中的值。