在tensorflow中重置图形

时间:2017-07-01 16:50:48

标签: python tensorflow

我有一个巨大的全局变量,用于火车和评估,但形状不同。现在我尝试在同一个过程中运行评估和训练,我偶然发现我无法真正删除张量流图中定义的变量。例如here建议的解决方法是使用reset_default_graph(),但这似乎不适用于图形上下文管理器。

import numpy as np
import tensorflow as tf

GRAPH = tf.Graph()

def train(examples):
    with GRAPH.as_default() as g:
        # actually this is huge variable
        global_var = tf.get_variable('global_var',
                                     initializer=np.full((examples, 32), 0.0),
                                     trainable=False)

def evaluate(examples):
    # tf.reset_default_graph() # ValueError: Variable input_var already exists
    with GRAPH.as_default() as g: # initialized to some other size
        tf.reset_default_graph() 
        global_var = tf.get_variable('global_var',
                                     initializer=np.full((examples, 32), 0.0),
                                     trainable=False)
       # in fact tensorflow creates a new graph and does not use GRAPH to define global_var

train(32)
evaluate(8)

结果:

Traceback (most recent call last):
  File "C:/Users/MrD/.PyCharm2017.1/config/scratches/scratch_44.py", line 22, in <module>
    evaluate(8)
  File "C:/Users/MrD/.PyCharm2017.1/config/scratches/scratch_44.py", line 19, in evaluate
    trainable=False)
  File "C:\_\Python35\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in get_controller
    if self.stack[-1] is not default:
IndexError: list index out of range

那么使用reset_default_graph()的正确方法是什么?是否真的没有办法重新定义变量丢弃旧的潜在巨大的初始化器?

1 个答案:

答案 0 :(得分:2)

事实证明,在图形上下文管理器中“重置默认图形”是没有意义的 - 请参阅:https://github.com/tensorflow/tensorflow/issues/11121。较新的版本应该添加更有用的错误消息:

AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph, exit the nesting and create a new graph.

正如上面的问题所讨论并实施here