可以在具有相同结构的不同图中使用tensorflow Saver

时间:2017-07-24 03:40:49

标签: python tensorflow

网络结构已加载到默认全局图中。我想创建另一个具有相同结构的图形,并在此图中加载检查点。

如果代码是这样的,它会抛出错误: ValueError:在最后一行没有要保存的变量 。但是,第二行工作正常。为什么? GraphDef返回的as_graph_def()是否包含变量定义/名称?

inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def)
    saver1 = tf.train.Saver()   

如果这样的代码,则会抛出错误 无法将feed_dict键解释为Tensor:名称' save / Const:0'是指最后一行中不存在的Tensor 。然而,第三行被移除后效果很好。

inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def)    
    with session.Session() as sess:        
        saver.restore(sess, checkpoint_path)

那么,这是否意味着Saver无法在不同的图表中工作,即使它们具有相同的结构?

任何帮助将不胜感激〜

1 个答案:

答案 0 :(得分:3)

以下是使用MetaGraphDef(与GraphDef保存变量集合不同)使用以前保存的图表初始化新图表的示例。

import tensorflow as tf

CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
with tf.Graph().as_default():
  some_variable = tf.get_variable(
    name="some_variable",
    shape=[2],
    dtype=tf.float32)
  init_op = tf.global_variables_initializer()
  first_meta_graph = tf.train.export_meta_graph()
  first_graph_saver = tf.train.Saver()
  with tf.Session() as session:
    init_op.run()
    print("Initialized value in first graph", some_variable.eval())
    first_graph_saver.save(
        sess=session,
        save_path=CHECKPOINT_PATH)

with tf.Graph().as_default():
  tf.train.import_meta_graph(first_meta_graph)
  second_graph_saver = tf.train.Saver()
  with tf.Session() as session:
    second_graph_saver.restore(
      sess=session,
      save_path=CHECKPOINT_PATH)
    print("Variable value after restore", tf.global_variables()[0].eval())

打印类似:

Initialized value in first graph [-0.98926258 -0.09709156]
Variable value after restore [-0.98926258 -0.09709156]

请注意,检查点仍然很重要!加载MetaGraph不会恢复Variables的值(它不包含这些值),只会跟踪跟踪其存在(收藏)的簿记。 SavedModel format解决了这个问题,将MetaGraph与检查点和其他元数据捆绑在一起运行它们。

编辑:根据大众需求,这是一个用GraphDef做同样事情的例子。我不推荐它。由于在加载GraphDef时没有恢复任何集合,因此我们必须手动指定我们希望Variables恢复的Saver; “import /”默认命名方案很容易通过name=''参数修复import_graph_def,但删除它并不是非常有用,因为如果你需要手动填写变量集合希望Saver能够“自动”工作。相反,我选择在创建Saver时手动指定映射。

import tensorflow as tf

CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
with tf.Graph().as_default():
  some_variable = tf.get_variable(
    name="some_variable",
    shape=[2],
    dtype=tf.float32)
  init_op = tf.global_variables_initializer()
  first_graph_def = tf.get_default_graph().as_graph_def()
  first_graph_saver = tf.train.Saver()
  with tf.Session() as session:
    init_op.run()
    print("Initialized value in first graph", some_variable.eval())
    first_graph_saver.save(
        sess=session,
        save_path=CHECKPOINT_PATH)

with tf.Graph().as_default():
  tf.import_graph_def(first_graph_def)
  variable_to_restore = tf.get_default_graph().get_tensor_by_name(
      "import/some_variable:0")
  second_graph_saver = tf.train.Saver(var_list={
      "some_variable": variable_to_restore
  })
  with tf.Session() as session:
    second_graph_saver.restore(
      sess=session,
      save_path=CHECKPOINT_PATH)
    print("Variable value after restore", variable_to_restore.eval())