BasicRNNCell object is not iterable

Asked: 2017-05-18 11:56:59

Tags: tensorflow

I am trying to build a sentiment analysis model in TensorFlow:

def rnn_lstm(weights, biases, data_x, sequence_length, vocab_size, embedding_size):
    # Use Tensor Flow embedding lookup and convert the input data set
    with tf.device("/cpu:0"):
        embedding = tf.get_variable("embedding43", [vocab_size, embedding_size])
        embedded_data = tf.nn.embedding_lookup(embedding, data_x)
        embedded_data_dropout = tf.nn.dropout(embedded_data, rnn_dropout_keep_prob)

    #add LSTM cell and dropout nodes
    rnn_lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(rnn_cell_size, forget_bias = rnn_lstm_forget_bias)
    rnn_lstm_cell = tf.contrib.rnn.core_rnn_cell.DropoutWrapper(rnn_lstm_cell, output_keep_prob = rnn_dropout_keep_prob)



    rnn_data_X = embedded_data_dropout
    # Permuting batch_size and sequence_length
    rnn_data_X = tf.transpose(rnn_data_X, [1, 0, 2])
    #print ("RNN After transpose rnn_data_X: ", rnn_data_X)
    # Reshaping to (sequence_length * batch_size, rnn_data_vec_size)
    rnn_data_X = tf.reshape(rnn_data_X, [-1, rnn_data_vec_size])
    #print ("RNN After reshape rnn_data_X: ", rnn_data_X)
    # Split to get a list of 'sequence_length' tensors of shape (batch_size, rnn_data_vec_size)
    rnn_data_X = tf.split(rnn_data_X,sequence_length,0)
    #print ("RNN After split len(rnn_data_X): ", len(rnn_data_X), rnn_data_X[0])

    # Get lstm cell output
    outputs, states = tf.contrib.rnn.core_rnn_cell.BasicRNNCell(rnn_lstm_cell, rnn_data_X)


    output = tf.matmul(outputs[-1], weights) + biases
    return output

But it throws an error saying that the BasicRNNCell object is not iterable. Please advise.

1 Answer:

Answer 0: (score: 0)

The problem is here:

# Get lstm cell output
outputs, states = tf.contrib.rnn.core_rnn_cell.BasicRNNCell(rnn_lstm_cell, rnn_data_X)

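This is not how a recurrent cell is meant to be used. BasicRNNCell is a cell class, so calling it constructs a new cell object instead of running the network, and unpacking that single object into outputs, states fails because it is not iterable. rnn_lstm_cell is already a (kind of) recurrent cell; to run it over the sequence, you need to call tf.nn.dynamic_rnn: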

# Get lstm cell output
# dynamic_rnn needs either dtype or initial_state to build the initial state
outputs, states = tf.nn.dynamic_rnn(rnn_lstm_cell, rnn_data_X, dtype=tf.float32)

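Note that tf.nn.dynamic_rnn expects a single 3-D tensor rather than a Python list of per-step tensors. With the list produced by tf.split in the question, the list-based tf.contrib.rnn.static_rnn(rnn_lstm_cell, rnn_data_X, dtype=tf.float32) is the matching call; alternatively, skip the reshape and split and pass the embedded tensor to dynamic_rnn directly.

You can learn more about recurrent models in TensorFlow here.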
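
The error message itself can be reproduced without TensorFlow. The sketch below uses a hypothetical FakeCell class as a stand-in for a cell class (an assumption for illustration only, not a TensorFlow API). It shows that calling a class returns one object, and that tuple-unpacking that object raises a TypeError because the object is not iterable. The exact wording of the message depends on the Python version.

```python
class FakeCell:
    """Hypothetical stand-in for an RNN cell class; not a TensorFlow API."""

    def __init__(self, num_units, inputs=None):
        # A constructor only stores configuration; it does not run a network.
        self.num_units = num_units
        self.inputs = inputs


def try_unpack():
    """Return the exception raised by unpacking a freshly constructed cell."""
    try:
        # Mirrors `outputs, states = BasicRNNCell(cell, data)` from the question:
        # the right-hand side is a single object, not a pair of values.
        outputs, states = FakeCell(128, ["step0", "step1"])
    except TypeError as exc:
        return exc
    return None


error = try_unpack()
print(type(error).__name__, "-", error)
```

The same thing happens in the question: the right-hand side is a cell instance, so Python has nothing to iterate over when it tries to fill outputs and states. A function such as dynamic_rnn or static_rnn, by contrast, returns an actual (outputs, state) pair.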