Tensorflow Estimator:如何在model_fn

时间:2017-10-16 08:40:28

标签: python tensorflow

我有一个非常简短的问题,这可能是一个非常简单的答案,但我无法理解,尽管我现在已经试了几个小时。

我正在使用Tensorflow Estimator,我想访问model_fn中的全局步骤。我试过tf.train.get_global_step,它给我一个Tensor。我需要将global_step作为整数(或作为字符串)!

所以我试过eval()(= tf.get_default_session()。run(t)),但它不起作用..

干杯!

2 个答案:

答案 0 :(得分:0)

您可以使用tf.cast将Tensor转换为intstring

例如,

tf.cast(tf.train.get_global_step(), dtype=tf.int)

参见参考here

答案 1 :(得分:0)

一种方法是从model_dir中的最新检查点文件中解析它。 因此,假设您可以将model_dir传递到model_fn中(通过params的{​​{1}}参数或通过tf.estimator.Estimator(..., params={'model_dir': 'path/to/model_dir'})传递),则可以使用以下实用程序函数:

tf.flags.FLAGS