每当我尝试使用tf.reset_default_graph()
时,我都会收到此错误:IndexError: list index out of range
或``。我应该在哪部分代码中使用它?我什么时候应该使用它?
编辑:
我更新了代码,但错误仍然存在。
def evaluate():
with tf.name_scope("loss"):
global x # x is a tf.placeholder()
xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))
loss = tf.reduce_mean(xentropy, name="loss")
with tf.name_scope("train"):
optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss)
with tf.name_scope("exec"):
with tf.Session() as sess:
for i in range(1, 2):
sess.run(tf.global_variables_initializer())
sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
print "Training " + str(i)
saver = tf.train.Saver()
saver.save(sess, "saved_models/testing")
print "Model Saved."
def predict():
with tf.name_scope("predict"):
tf.reset_default_graph()
with tf.Session() as sess:
saver = tf.train.import_meta_graph("saved_models/testing.meta")
saver.restore(sess, "saved_models/testing")
output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')
print sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])})
def main():
print "Starting Program..."
evaluate()
writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())
predict()
如果我从更新后的代码中删除了tf.reset_default_graph(),我会收到此错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used
根据我目前的理解,tf.reset_default_graph()会删除所有图表,因此我避免了上面提到的错误(ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used
)
答案 0 :(得分:14)
这可能就是你使用它的方式:
import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
tf.reset_default_graph()
您收到错误是因为您在会话中使用它。来自tf.reset_default_graph()
文档:
当我在jupyter笔记本中进行实验时,在tf.Session或tf.InteractiveSession时调用此函数 active将导致未定义的行为。使用任何以前创建的 调用此函数后,tf.Operation或tf.Tensor对象将会 导致未定义的行为
tf.reset_default_graph()
在测试阶段可能会有所帮助(至少对我而言)。但是,我从来没有在生产中使用它,也没有看到它会如何有用。
以下是可能在笔记本中的示例:
import tensorflow as tf
# create some graph
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(...)
现在我不再需要这些东西,但是如果我创建另一个图形并在tensorboard中可视化,我会看到旧节点和新节点。为了解决这个问题,我可以重启内核并只运行下一个单元。但是,我可以这样做:
tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
print sess.run(...)
OP添加他的代码后编辑:
with tf.name_scope("predict"):
tf.reset_default_graph()
这是近似发生的事情。您的代码失败,因为tf.name_scope
已经向图表中添加了某些内容。当你在“向图表中添加内容”时,你会告诉TF完全删除图表,但它不能,因为它正忙于添加内容。
答案 1 :(得分:1)
出于某种原因,我需要构建一个新的图表,我已经测试了,最终有效!非常感谢Salvador Dali的回答: - )
import tensorflow as tf
from my_models import Classifier
for i in range(10):
tf.reset_default_graph()
# build the graph
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
classifier = Classifier(global_step)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print("do sth here.")
答案 2 :(得分:0)
简单地说, 用于清除以前使用sess.run()创建的占位符
答案 3 :(得分:0)
随着TensorFlow 2.0的发布,现在最好使用tf.compat.v1.reset_default_graph()
以避免收到警告。链接到文档:https://www.tensorflow.org/api_docs/python/tf/compat/v1/reset_default_graph