_ = importer.import_graph_def(input_graph_def, name='')
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(','),
variable_names_blacklist=variable_names_blacklist)
在上面的代码中,导入器用于将graphDef导入当前默认图形,并且保护程序加载以前训练的值。问题是存储这些训练值的位置?在会话中,在input_graph_def中,在当前图形结构(tf.get_default_graph())中还是在保护程序中?
我检查方法convert_variables_to_constants
的代码。 https://github.com/tensorflow/tensorflow/blob/235192d47cfb375c0cc93c1deefb9e440715bf35/tensorflow/python/framework/graph_util_impl.py
它使用sess.run(变量名)来获取加载的值。这个sess.run
从哪里获取值?
答案 0 :(得分:0)
当我们定义保护程序时,我们应该将注释传递给它(默认情况下它是全局变量)。
In [2]: import tensorflow as tf
In [3]: a = tf.get_variable("a", [])
In [5]: saver_a = tf.train.Saver({"my_a_in_ckpt": a}) # here "my_a_in_ckpt" should match that as you defined in step (1)
In [7]: sess = tf.Session()
In [9]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))
INFO:tensorflow:Restoring parameters from ./temp_model/temp
In [10]: sess.run(a)
Out[10]: 0.43891537
In [11]: sess.run(b)
Out[11]: 1.5962805
这里我们首先初始化所有变量并保存到" ./ temp_model"。要恢复变量:
In [12]: saver_b = tf.train.Saver({"b": b})
In [13]: saver_b.save(sess, "./temp_model_b/temp")
Out[13]: './temp_model_b/temp'
我们可以将a和b保存到不同的地方:
In [3]: a = tf.get_variable("a", [])
In [4]: b = tf.get_variable("b", [])
In [5]: saver_b = tf.train.Saver({"b": b})
In [6]: saver_a = tf.train.Saver({"my_a_in_ckpt": a})
In [7]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))
INFO:tensorflow:Restoring parameters from ./temp_model/temp
In [8]: saver_b.restore(sess, tf.train.latest_checkpoint("./temp_model_b"))
INFO:tensorflow:Restoring parameters from ./temp_model_b/temp
In [9]: sess.run(a)
Out[9]: 0.43891537
In [10]: sess.run(b)
Out[10]: 1.5962805
并将其恢复为图表:
{{1}}