如何在tensorflow中恢复部分图?

时间:2017-09-07 20:22:13

标签: python machine-learning tensorflow deep-learning

我想在张量流中仅恢复计算图的一部分。我的架构包含两个网络。第一网络的输出是第二网络的输入。第一个网络是预训练的,我想从检查点恢复。我也不想更新第一个网络的参数。是否有一个我可以遵循的例子来实现这个目标?

由于

1 个答案:

答案 0 :(得分:4)

我没有确切的代码,但这里有一个简短的指南,可以帮助您:

首先您需要将网络解析为tf.GraphDef格式 代码应该是这样的:

graph_def = tf.GraphDef()
with tf.gfile.FastGFile("path/to/graphdef") as f:
  s = f.read()
graph_def.ParseFromString(s)

或从检查点/ saved_mode恢复,然后通过以下方式转换为GraphDef

tf.train.import_meta_graph('checkpoint.meta')
tf.get_default_graph().as_graph_def()

现在你有了graph_def

第二次,使用graph_def提取tf.graph_util.extract_sub_graph的子图,您也可以指定您输入到第二个网络的目标节点。

上次,使用tf.import_graph_def从第二步导入子图。

,因为您不想更新第一个网络的参数,您可以使用tf.graph_util.convert_variables_to_constants冻结其参数