Tensorflow Metagraph Fundamentals

时间:2016-10-01 00:44:22

标签: tensorflow

我想训练我的Tensorflow模型,冻结快照,然后使用新的输入数据以前馈模式(无需进一步培训)运行它。问题:

  1. tf.train.export_meta_graphtf.train.import_meta_graph是否适用于此工具?
  2. 我是否需要在collection_list中包含我想要包含在快照中的所有变量的名称? (对我来说最简单的就是包括所有内容。)
  3. Tensorflow文档说:“如果未指定collection_list,则将导出模型中的所有集合。”这是否意味着如果我在collection_list中未指定变量,那么模型中的所有变量都会导出,因为它们位于默认集合中?
  4. Tensorflow文档说:“为了将Python对象与MetaGraphDef进行序列化,Python类必须实现to_proto()和from_proto()方法,并使用register_proto_function将它们注册到系统中。 / em>“这是否意味着to_proto()from_proto()必须仅添加到我已定义并希望导出的类中?如果我只使用标准的Python数据类型(int,float,list,dict)那么这是不相关的吗?
  5. 提前致谢。

1 个答案:

答案 0 :(得分:3)

有点晚但我还是会尝试回答。

  
      
  1. tf.train.export_meta_graphtf.train.import_meta_graph是否适用于此工具?
  2.   

我会这么说。请注意,当您通过tf.train.export_meta_graph保存模型时,会隐式调用tf.train.Saver。要点是:

# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
    ...
    # save graph and variables
    # if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
    saver.save(sess, save_path, global_step)

然后恢复:

save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)

请注意,您可以调用原先用于创建模型的原始代码,而不是调用tf.train.import_meta_graph。但是,我认为使用import_meta_graph更优雅,因为即使您无法访问创建它的代码,也可以恢复模型。

  
      
  1. 我是否需要在collection_list中包含我想要包含在快照中的所有变量的名称? (对我来说最简单的就是包括所有内容。)
  2.   

没有。但问题有点令人困惑:collection_list中的export_meta_graph并不是变量列表,而是集合(即字符串键列表)。

收藏非常方便,例如所有可训练变量都自动包含在集合tf.GraphKeys.TRAINABLE_VARIABLES中,您可以通过调用来获取:

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

tf.trainable_variables()  # defaults to the default graph

如果在恢复之后您需要访问除可训练变量之外的其他中间结果,我发现将它们放入自定义集合非常方便,如下所示:

...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)

自动存储此集合(除非您通过在collection_list参数中省略export_meta_graph参数中此集合的名称来明确指定。因此,您可以在恢复后简单地检索input_占位符,如下所示:

...
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    input_ = tf.get_collection_ref('my_custom_collection')[0]
  
      
  1. Tensorflow文档说:“如果未指定collection_list,则将导出模型中的所有集合。”这是否意味着如果我在{{1}中未指定任何变量然后导出模型中的所有变量,因为它们在默认集合中?
  2.   

是。再次注意collection_list是集合列表而不是变量的细微细节。实际上,如果您只想保存某些变量,则可以在构造collection_list对象时指定这些变量。来自tf.train.Saver

的文档
tf.train.Saver.__init__
  
      
  1. Tensorflow文档说:“为了将Python对象与MetaGraphDef进行序列化,Python类必须实现    """Creates a `Saver`. The constructor adds ops to save and restore variables. `var_list` specifies the variables that will be saved and restored. It can be passed as a `dict` or a list: * A `dict` of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files. * A list of variables: The variables will be keyed with their op name in the checkpoint files. to_proto()方法,并将其注册到系统中   使用register_proto_function。“这是否意味着from_proto()和   to_proto()必须仅添加到我定义的类中   想出口?如果我只使用标准的Python数据类型(int,   float,list,dict)那么这是无关紧要的吗?
  2.   

我从未使用过此功能,但我会说你的解释是正确的。