它是什么意思"输入必须是一个列表"?

时间:2016-09-22 05:14:08

标签: tensorflow

下面的代码显示我"输入必须是一个列表"。在这。

outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

当我为输入x定义占位符时。我已经将形状设置为[None,None]。我认为这个形状是二维数组。但是,代码不断要求列表类型为x

下面,我在训练前附上了我的所有代码。并且这些代码被插入到类的功能中。

x = tf.placeholder("float",[None,None])
y = tf.placeholder("float",[None])

lstm_cell = rnn_cell.BasicLSTMCell(self.n_hidden, forget_bias=1.0)

outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

pred = tf.matmul(outpus[-1], self.weights['out']) + self.biases['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred,y))
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.initialize_all_variables()

self.sess = tf.Session()

self.sess.run(init)

此外,实际输入将是单词序列的浮动和标签的浮动,形成为x=[["aaa","aaa","aaa"],["bbb","bbb"]]y=["c1","c2"]

此时,x的第一个元素数组标有" c1"第二个是" c2"。特别是,x的每个元素数组的大小不能是确定性的。

1 个答案:

答案 0 :(得分:0)

正如documentation所述,函数inputs的参数tf.nn.rnn()是:

  

输入:输入的长度为T的列表,每个都是形状张量[batch_size,input_size]或这些元素的嵌套元组。

在您的代码中,参数inputsx Tensor 形状[None, None]的占位符。为了使您的代码有效,x必须是形状[None, input_lenght]的T张量列表。

以下代码生成张量列表inputs,因此函数tf.nn.rnn有效。

import tensorflow as tf

x = tf.placeholder("float",[None,16])
y = tf.placeholder("float",[None])

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(256, forget_bias=1.0)

inputs = []
for t in range(10):
        inputs.append(x)

print(len(inputs))

outputs, states = tf.nn.rnn(lstm_cell, inputs, dtype=tf.float32)

pred = tf.matmul(outputs[-1], self.weights['out']) + self.biases['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred,y))
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.initialize_all_variables()

self.sess = tf.Session()

self.sess.run(init)

请注意占位符x的定义形状为[None, input_shape]。它不能使用形状[None, None],因为第一个维度是batch_size,可以是None,但第二个维度是输入序列中每个项目的大小,并且该值不能None