如何从tf.estimator获取默认会话?

时间:2018-03-15 08:10:42

标签: python tensorflow

我正在尝试使用高API tf.estimator,但我发现很难让会话调试某些内部结果,例如全局步骤。

cls = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': fcs,
        'hidden_units': [10, 10],
        'n_classes': 3,
    })

来自https://www.tensorflow.org/versions/master/get_started/custom_estimators

的示例

我已尝试sess = tf.get_default_sessionwith tf.Session() as sess,但无法获得defut会话。

1 个答案:

答案 0 :(得分:3)

最简单的方法是使用tf.Print,如:

...
global_step = tf.Print(global_step, [global_step], message='Value of global step")
...

您可以将global_step替换为您要打印的任何张量。然后,当您运行训练时,它将在每次评估张量时打印值。

另一种更复杂的方法是导出模型,然后使用您自己的会话(而不是估算器api)将其加载回来。完成此操作后,您可以针对任何已定义的操作调用session.run。您可以使用tf.get_operation_by_nametf.get_tensor_by_name进行操作。您还可以提供您想要的任何值作为输入。