考虑一个例子,考虑在张量流中计算内积。当我用一个使用feed的会话对它进行评估时,我试图尝试在TensorFlow中用图表引用事物的不同方法。请考虑以下代码:
import numpy as np
import tensorflow as tf
M = 4
D = 2
D1 = 3
x = tf.placeholder(tf.float32, shape=[M, D], name='data_x') # M x D
W = tf.Variable( tf.truncated_normal([D,D1], mean=0.0, stddev=0.1) ) # (D x D1)
b = tf.Variable( tf.constant(0.1, shape=[D1]) ) # (D1 x 1)
inner_product = tf.matmul(x,W) + b # M x D1
with tf.Session() as sess:
sess.run( tf.initialize_all_variables() )
x_val = np.random.rand(M,D)
#print type(x.name)
#print x.name
name = x.name
ans = sess.run(inner_product, feed_dict={name: x_val})
ans = sess.run(inner_product, feed_dict={x.name: x_val})
ans = sess.run(inner_product, feed_dict={x: x_val})
name_str = unicode('data_x', "utf-8")
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
print ans
以下工作:
ans = sess.run(inner_product, feed_dict={name: x_val})
ans = sess.run(inner_product, feed_dict={x.name: x_val})
ans = sess.run(inner_product, feed_dict={x: x_val})
但是最后三个:
name_str = unicode('data_x', "utf-8")
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
别。我检查了为什么类型x.name
,但它仍然无法工作,即使我将它转换为类型python解释器说它是。我documentation似乎说密钥必须是张量。但是,它接受x.name
而不是张量(它是<type 'unicode'>
),有人知道最近发生了什么吗?
我可以粘贴文档说它需要是一个张量:
可选的feed_dict参数允许调用者覆盖 图中张量的值。 feed_dict中的每个键都可以是其中之一 以下类型:
如果键是Tensor,则值可能是Python标量,字符串, list或numpy ndarray,可以转换为与该dtype相同的dtype 张量。另外,如果键是占位符,则形状为 将检查value是否与占位符兼容。如果 key是一个SparseTensor,值应该是SparseTensorValue。每 feed_dict中的值必须可以转换为dtype的numpy数组 相应的密钥。
答案 0 :(得分:1)
TensorFlow主要期望tf.Tensor
个对象作为提要词典中的键。如果它等于会话图中某些bytes
的{{1}}属性,它也会接受一个字符串(可能是unicode
或.name
)。
在您的示例中,tf.Tensor
有效,因为x.name
是x
并且您正在评估其tf.Tensor
属性。 .name
不起作用,因为它是"data_val"
(即tf.Operation
)的名称,而不是x.op
的名称,tf.Tensor
是{{1}的输出}}。如果您打印tf.Operation
,则会看到它的值为x.name
,这意味着“"data_val:0"
的第0个输出称为tf.Operation
。