如何在TensorFlow中获取Estimator的默认会话?

时间:2018-01-16 17:41:44

标签: python tensorflow

我已使用Estimator的export_savedmodel()功能创建了一个Estimator并将其导出到SavedModel文件。

出于可重复性原因,我希望能够重新创建Estimator,在SavedModel文件中加载变量,然后调用evaluate()并获得相同的结果。

我认为这样做的方法是创建我的SessionRunHook来进行加载并将其传递给evaluate()中的hooks参数,如下所示:

class myhook(tf.train.SessionRunHook):

    def begin(self):
        tf.saved_model.loader.load(tf.get_default_session(), ['serve'], '../best_model/1516075471/')


load_best_model_hook = myhook()

res2 = da_model.evaluate(test_input_fn, hooks=[load_best_model_hook])

但这会产生以下错误:

File "/home/user7891/Code/scratch.py", line 106, in begin
    tf.saved_model.loader.load(tf.get_default_session(), ['serve'], '../best_model/1516075471/')
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/saved_model/loader_impl.py", line 198, in load
    with sess.graph.as_default():
AttributeError: 'NoneType' object has no attribute 'graph'

看起来在调用begin()时未创建会话。我无法覆盖after_create_session,因为此时无法修改图表。

0 个答案:

没有答案