我想训练我的Tensorflow模型,冻结快照,然后使用新的输入数据以前馈模式(无需进一步培训)运行它。问题:
tf.train.export_meta_graph
和tf.train.import_meta_graph
是否适用于此工具?collection_list
中包含我想要包含在快照中的所有变量的名称? (对我来说最简单的就是包括所有内容。)collection_list
,则将导出模型中的所有集合。”这是否意味着如果我在collection_list
中未指定变量,那么模型中的所有变量都会导出,因为它们位于默认集合中?to_proto()
和from_proto()
必须仅添加到我已定义并希望导出的类中?如果我只使用标准的Python数据类型(int,float,list,dict)那么这是不相关的吗?提前致谢。
答案 0 :(得分:3)
有点晚但我还是会尝试回答。
- 醇>
tf.train.export_meta_graph
和tf.train.import_meta_graph
是否适用于此工具?
我会这么说。请注意,当您通过tf.train.export_meta_graph
保存模型时,会隐式调用tf.train.Saver
。要点是:
# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
...
# save graph and variables
# if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
saver.save(sess, save_path, global_step)
然后恢复:
save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
请注意,您可以调用原先用于创建模型的原始代码,而不是调用tf.train.import_meta_graph
。但是,我认为使用import_meta_graph
更优雅,因为即使您无法访问创建它的代码,也可以恢复模型。
- 我是否需要在
醇>collection_list
中包含我想要包含在快照中的所有变量的名称? (对我来说最简单的就是包括所有内容。)
没有。但问题有点令人困惑:collection_list
中的export_meta_graph
并不是变量列表,而是集合(即字符串键列表)。
收藏非常方便,例如所有可训练变量都自动包含在集合tf.GraphKeys.TRAINABLE_VARIABLES
中,您可以通过调用来获取:
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
或
tf.trainable_variables() # defaults to the default graph
如果在恢复之后您需要访问除可训练变量之外的其他中间结果,我发现将它们放入自定义集合非常方便,如下所示:
...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)
自动存储此集合(除非您通过在collection_list
参数中省略export_meta_graph
参数中此集合的名称来明确指定。因此,您可以在恢复后简单地检索input_
占位符,如下所示:
...
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
input_ = tf.get_collection_ref('my_custom_collection')[0]
- Tensorflow文档说:“如果未指定
醇>collection_list
,则将导出模型中的所有集合。”这是否意味着如果我在{{1}中未指定任何变量然后导出模型中的所有变量,因为它们在默认集合中?
是。再次注意collection_list
是集合列表而不是变量的细微细节。实际上,如果您只想保存某些变量,则可以在构造collection_list
对象时指定这些变量。来自tf.train.Saver
:
tf.train.Saver.__init__
- Tensorflow文档说:“为了将Python对象与MetaGraphDef进行序列化,Python类必须实现
醇>"""Creates a `Saver`. The constructor adds ops to save and restore variables. `var_list` specifies the variables that will be saved and restored. It can be passed as a `dict` or a list: * A `dict` of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files. * A list of variables: The variables will be keyed with their op name in the checkpoint files.
和to_proto()
方法,并将其注册到系统中 使用register_proto_function。“这是否意味着from_proto()
和to_proto()
必须仅添加到我定义的类中 想出口?如果我只使用标准的Python数据类型(int, float,list,dict)那么这是无关紧要的吗?
我从未使用过此功能,但我会说你的解释是正确的。