@InjectMocks
作为Tensorflow的新手,我很难理解在此框架中如何管理占位符。
如果我是第一次运行上面的代码,它会返回9(正确的值)。 但是如果我在同一个jupyter会话中再次运行它,我会得到以下错误。就好像全局变量(在这种情况下是占位符)没有得到清理,尽管我正在使用""关闭会议
堆栈追踪:
import tensorflow as tf
y_hat = tf.constant(36, name='y_hat') # Define y_hat constant. Set to 36.
yy = tf.placeholder(tf.int32, shape=[])
loss = tf.Variable((yy - y_hat)**2, name='loss') # Create a variable for the loss
init = tf.global_variables_initializer()
with tf.Session() as session:
session.run(tf.global_variables_initializer(), feed_dict = {yy: 39})
print(session.run(loss, feed_dict={yy: 39}))
知道发生了什么以及如何解决这个问题? 感谢
答案 0 :(得分:2)
在tf.reset_default_graph()
下方添加行import tensorflow as tf
,以便每次运行代码时都会重置张量流图。那你就不会得到这个错误。
顺便说一句,您实际上不需要将loss
指定为变量。你可以运行
import tensorflow as tf
y_hat = tf.constant(36, name='y_hat')
yy = tf.placeholder(tf.int32, shape=[])
loss = (yy - y_hat)**2
with tf.Session() as session:
print(session.run(loss, feed_dict={yy: 39}))
以上代码打印9。