如何使用tf.reset_default_graph()

时间:2017-07-04 02:55:32

标签: machine-learning tensorflow

每当我尝试使用tf.reset_default_graph()时,我都会收到此错误:IndexError: list index out of range或``。我应该在哪部分代码中使用它?我什么时候应该使用它?

编辑:

我更新了代码,但错误仍然存​​在。

def evaluate():
    with tf.name_scope("loss"):
        global x # x is a tf.placeholder()
        xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))
        loss = tf.reduce_mean(xentropy, name="loss")

    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)

    with tf.name_scope("exec"):
        with tf.Session() as sess:
            for i in range(1, 2):
                sess.run(tf.global_variables_initializer())
                sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
                print "Training " + str(i)
                saver = tf.train.Saver()
                saver.save(sess, "saved_models/testing")
                print "Model Saved."


def predict():
    with tf.name_scope("predict"):
        tf.reset_default_graph()
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')
            print sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])})


def main():
    print "Starting Program..."
    evaluate()
    writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())
    predict()

如果我从更新后的代码中删除了tf.reset_default_graph(),我会收到此错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

根据我目前的理解,tf.reset_default_graph()会删除所有图表,因此我避免了上面提到的错误(ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

4 个答案:

答案 0 :(得分:14)

这可能就是你使用它的方式:

import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
    tf.reset_default_graph()

您收到错误是因为您在会话中使用它。来自tf.reset_default_graph()文档:

  

在tf.Session或tf.InteractiveSession时调用此函数   active将导致未定义的行为。使用任何以前创建的   调用此函数后,tf.Operation或tf.Tensor对象将会   导致未定义的行为

当我在jupyter笔记本中进行实验时,

tf.reset_default_graph()在测试阶段可能会有所帮助(至少对我而言)。但是,我从来没有在生产中使用它,也没有看到它会如何有用。

以下是可能在笔记本中的示例:

import tensorflow as tf
# create some graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(...)

现在我不再需要这些东西,但是如果我创建另一个图形并在tensorboard中可视化,我会看到旧节点和新节点。为了解决这个问题,我可以重启内核并只运行下一个单元。但是,我可以这样做:

tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
    print sess.run(...)

OP添加他的代码后编辑

with tf.name_scope("predict"):
    tf.reset_default_graph()

这是近似发生的事情。您的代码失败,因为tf.name_scope已经向图表中添加了某些内容。当你在“向图表中添加内容”时,你会告诉TF完全删除图表,但它不能,因为它正忙于添加内容。

答案 1 :(得分:1)

出于某种原因,我需要构建一个新的图表,我已经测试了,最终有效!非常感谢Salvador Dali的回答: - )

import tensorflow as tf
from my_models import Classifier

for i in range(10):
    tf.reset_default_graph()
    # build the graph
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    classifier = Classifier(global_step)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print("do sth here.")

答案 2 :(得分:0)

简单地说, 用于清除以前使用sess.run()创建的占位符

答案 3 :(得分:0)

随着TensorFlow 2.0的发布,现在最好使用tf.compat.v1.reset_default_graph()以避免收到警告。链接到文档:https://www.tensorflow.org/api_docs/python/tf/compat/v1/reset_default_graph