网络结构已加载到默认全局图中。我想创建另一个具有相同结构的图形,并在此图中加载检查点。
如果代码是这样的,它会抛出错误: ValueError:在最后一行没有要保存的变量 。但是,第二行工作正常。为什么? GraphDef
返回的as_graph_def()
是否包含变量定义/名称?
inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
tf.import_graph_def(inference_graph_def)
saver1 = tf.train.Saver()
如果这样的代码,则会抛出错误 无法将feed_dict键解释为Tensor:名称' save / Const:0'是指最后一行中不存在的Tensor 。然而,第三行被移除后效果很好。
inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
tf.import_graph_def(inference_graph_def)
with session.Session() as sess:
saver.restore(sess, checkpoint_path)
那么,这是否意味着Saver无法在不同的图表中工作,即使它们具有相同的结构?
任何帮助将不胜感激〜
答案 0 :(得分:3)
以下是使用MetaGraphDef
(与GraphDef
保存变量集合不同)使用以前保存的图表初始化新图表的示例。
import tensorflow as tf
CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
with tf.Graph().as_default():
some_variable = tf.get_variable(
name="some_variable",
shape=[2],
dtype=tf.float32)
init_op = tf.global_variables_initializer()
first_meta_graph = tf.train.export_meta_graph()
first_graph_saver = tf.train.Saver()
with tf.Session() as session:
init_op.run()
print("Initialized value in first graph", some_variable.eval())
first_graph_saver.save(
sess=session,
save_path=CHECKPOINT_PATH)
with tf.Graph().as_default():
tf.train.import_meta_graph(first_meta_graph)
second_graph_saver = tf.train.Saver()
with tf.Session() as session:
second_graph_saver.restore(
sess=session,
save_path=CHECKPOINT_PATH)
print("Variable value after restore", tf.global_variables()[0].eval())
打印类似:
Initialized value in first graph [-0.98926258 -0.09709156]
Variable value after restore [-0.98926258 -0.09709156]
请注意,检查点仍然很重要!加载MetaGraph
不会恢复Variables
的值(它不包含这些值),只会跟踪跟踪其存在(收藏)的簿记。 SavedModel format解决了这个问题,将MetaGraph
与检查点和其他元数据捆绑在一起运行它们。
编辑:根据大众需求,这是一个用GraphDef
做同样事情的例子。我不推荐它。由于在加载GraphDef
时没有恢复任何集合,因此我们必须手动指定我们希望Variables
恢复的Saver
; “import /”默认命名方案很容易通过name=''
参数修复import_graph_def
,但删除它并不是非常有用,因为如果你需要手动填写变量集合希望Saver
能够“自动”工作。相反,我选择在创建Saver
时手动指定映射。
import tensorflow as tf
CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
with tf.Graph().as_default():
some_variable = tf.get_variable(
name="some_variable",
shape=[2],
dtype=tf.float32)
init_op = tf.global_variables_initializer()
first_graph_def = tf.get_default_graph().as_graph_def()
first_graph_saver = tf.train.Saver()
with tf.Session() as session:
init_op.run()
print("Initialized value in first graph", some_variable.eval())
first_graph_saver.save(
sess=session,
save_path=CHECKPOINT_PATH)
with tf.Graph().as_default():
tf.import_graph_def(first_graph_def)
variable_to_restore = tf.get_default_graph().get_tensor_by_name(
"import/some_variable:0")
second_graph_saver = tf.train.Saver(var_list={
"some_variable": variable_to_restore
})
with tf.Session() as session:
second_graph_saver.restore(
sess=session,
save_path=CHECKPOINT_PATH)
print("Variable value after restore", variable_to_restore.eval())