如何在Tensorflow中使用占位符从占位符列表中获取占位符?

时间:2018-03-14 06:36:00

标签: tensorflow

我有一个占位符列表如下:

input_vars = []
input_vars.append(tf.placeholder(shape=[None, 5], dtype=tf.float32, name="place0"))
input_vars.append(tf.placeholder(shape=[None, 5], dtype=tf.float32, name="place1"))
input_vars.append(tf.placeholder(shape=[None, 5], dtype=tf.float32, name="place2"))

我想基于int占位符访问不同的占位符,如下所示:

which_input = tf.placeholder(tf.int32)

在会话中调用以下内容时:

input_vars[which_input]

我收到以下错误:

  

TypeError:列表索引必须是整数,而不是Tensor

我尝试使用tf.gather,但是当我想在密集层中提供选定的占位符时,如下所示:

helpme = tf.gather(input_vars, which_input)
l_in = tf.layers.dense(inputs=helpme, units=64, activation=tf.nn.relu, trainable=True)

我收到以下错误:

  

ValueError:图层dense_4的输入0与图层不兼容:其排名未定义,但图层需要定义的排名。

这是会话运行信息:

x = [[1,2,3,4,5]]
x.append([6,7,8,9,10])

y = [[5,4,3,2,1]]
y.append([5,3,2,1,1])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    dictd = dict()
    dictd[input_vars[0]] = x
    dictd[input_vars[1]] = y
    dictd[input_vars[2]] = x
    dictd[which_input] = 2

    print sess.run(l_in, feed_dict=dictd)

我错过了什么吗?怎么办呢?

1 个答案:

答案 0 :(得分:1)

您需要按照this回答中的说明重新设置tf.gather的输出:

l_in = tf.layers.dense(inputs=tf.reshape(helpme, shape=[-1,5]), units=64, activation=tf.nn.relu, trainable=True)