您可以提供哪些数据类型作为TensorFlow中的输入键?

时间:2016-06-21 21:40:11

标签: tensorflow

考虑一个例子,考虑在张量流中计算内积。当我用一个使用feed的会话对它进行评估时,我试图尝试在TensorFlow中用图表引用事物的不同方法。请考虑以下代码:

import numpy as np
import tensorflow as tf

M = 4
D = 2
D1 = 3
x = tf.placeholder(tf.float32, shape=[M, D], name='data_x') # M x D
W = tf.Variable( tf.truncated_normal([D,D1], mean=0.0, stddev=0.1) ) # (D x D1)
b = tf.Variable( tf.constant(0.1, shape=[D1]) ) # (D1 x 1)
inner_product = tf.matmul(x,W) + b # M x D1
with tf.Session() as sess:
    sess.run( tf.initialize_all_variables() )
    x_val = np.random.rand(M,D)
    #print type(x.name)
    #print x.name
    name = x.name
    ans = sess.run(inner_product, feed_dict={name: x_val})
    ans = sess.run(inner_product, feed_dict={x.name: x_val})
    ans = sess.run(inner_product, feed_dict={x: x_val})
    name_str = unicode('data_x', "utf-8")
    ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
    print ans

以下工作:

ans = sess.run(inner_product, feed_dict={name: x_val})
ans = sess.run(inner_product, feed_dict={x.name: x_val})
ans = sess.run(inner_product, feed_dict={x: x_val})

但是最后三个:

name_str = unicode('data_x', "utf-8")
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work

别。我检查了为什么类型x.name,但它仍然无法工作,即使我将它转换为类型python解释器说它是。我documentation似乎说密钥必须是张量。但是,它接受x.name而不是张量(它是<type 'unicode'>),有人知道最近发生了什么吗?

我可以粘贴文档说它需要是一个张量:

  

可选的feed_dict参数允许调用者覆盖   图中张量的值。 feed_dict中的每个键都可以是其中之一   以下类型:

     

如果键是Tensor,则值可能是Python标量,字符串,   list或numpy ndarray,可以转换为与该dtype相同的dtype   张量。另外,如果键是占位符,则形状为   将检查value是否与占位符兼容。如果   key是一个SparseTensor,值应该是SparseTensorValue。每   feed_dict中的值必须可以转换为dtype的numpy数组   相应的密钥。

1 个答案:

答案 0 :(得分:1)

TensorFlow主要期望tf.Tensor个对象作为提要词典中的键。如果它等于会话图中某些bytes的{​​{1}}属性,它也会接受一个字符串(可能是unicode.name)。

在您的示例中,tf.Tensor有效,因为x.namex并且您正在评估其tf.Tensor属性。 .name不起作用,因为它是"data_val"(即tf.Operation)的名称,而不是x.op的名称,tf.Tensor是{{1}的输出}}。如果您打印tf.Operation,则会看到它的值为x.name,这意味着“"data_val:0"的第0个输出称为tf.Operation