在TensorFlow图中删除除了几个节点之外的所有节点

时间:2016-05-08 17:54:42

标签: python tensorflow deep-learning

我的TensorFlow用例要求我为每个需要处理的实例构建一个新的计算图。这最终会炸掉内存需求。

除了一些tf.Variables模型参数外,我还想删除所有其他节点。其他有类似问题的人发现tf.reset_default_graph()很有用,但这样可以摆脱我需要坚持的模型参数。

我可以使用哪些方法删除除这些节点之外的所有节点?

编辑: 实例特定的计算实际上只意味着我正在添加许多新操作。我相信这些操作是内存问题背后的原因。

更新 查看最近发布的tensorflow fold(https://github.com/tensorflow/fold),它允许动态构建计算图。

1 个答案:

答案 0 :(得分:12)

tf.graph数据结构被设计为仅附加数据结构。因此,无法删除或修改现有节点。通常这不是问题,因为在运行会话时只处理必要的子图。

您可以尝试将图表的Variabels复制到新图表中并删除旧图表。要存档这只是运行:

old_graph = tf.get_default_graph() # Save the old graph for later iteration
new_graph = tf.graph() # Create an empty graph
new_graph.set_default() # Makes the new graph default

如果要迭代旧图中的所有节点,请使用:

for node in old_graph.get_operations():
    if node.type == 'Variable':
       # read value of variable and copy it into new Graph

或者您也可以使用:

for node in old_graph.get_collection('trainable_variables'):
   # iterates over all trainable Variabels
   # read and create new variable

还要查看python/framework/ops.py : 1759以了解更多关于操纵图表中节点的方法。

然而,在您使用tf.Graph之前,我强烈建议您考虑这是否真的需要。通常可以尝试概括计算并使用共享变量构建图形,以便您要处理的每个实例都是该图形的子图。