我正在尝试熟悉TensorFlow,而且我不确定占位符,变量等。为了简单起见,我尝试创建一个非常简单的计算 - 占位符和变量,它只是占位符乘以2。
我把所有东西放在一个函数中,就像这样:
import tensorflow as tf
def try_variable(value):
x = tf.placeholder(tf.float64, name='x')
v = tf.Variable(x * 2, name='v', validate_shape=False)
with tf.Session() as session:
init = tf.global_variables_initializer()
session.run(init, feed_dict={x: value})
return session.run(v)
然后我调用函数:
print(try_variable(80))
确实输出是160。
但是当我再次打电话时:
print(try_variable(80))
我收到错误:
InvalidArgumentError:您必须使用dtype double为占位符张量'x'提供值
我错过了什么?
答案 0 :(得分:4)
现在你每次调用函数时都会创建一个新的变量和占位符,所以第二次调用try_variable
函数时,你实际上有2个占位符和2个TensorFlow变量! x
,x_1
,v
,v_1
。
因此,在第二次运行init操作时,只为占位符x_1
提供初始值,该占位符现已绑定到python变量x
。
如果要在当前图表中打印所有张量的名称,可以调用
print [n.name for n in tf.get_default_graph().as_graph_def().node]
如果您仍希望每次调用该函数时创建2个新的张量,一个选项是使用命令tf.reset_default_graph()
重置默认图形
每次调用该函数时 - 都是非常不推荐的。