Tensor Flow - LSTM - ' Tensor'对象不可迭代

时间:2016-11-07 11:33:12

标签: python neural-network tensorflow lstm

您好我正在为lstm rnn cell使用以下功能。

def LSTM_RNN(_X, _istate, _weights, _biases):
    # Function returns a tensorflow LSTM (RNN) artificial neural network from given parameters. 
    # Note, some code of this notebook is inspired from an slightly different 
    # RNN architecture used on another dataset: 
    # https://tensorhub.com/aymericdamien/tensorflow-rnn

    # (NOTE: This step could be greatly optimised by shaping the dataset once
    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size

    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)

    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define a lstm cell with tensorflow
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)


    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

函数的输出存储在pred变量下。

pred = LSTM_RNN(x, istate, weights, biases)

但它显示以下错误。 (表明张量对象不可迭代。)

以下是ERROR图片链接 - http://imgur.com/a/NhSFK

请帮助我,如果这个问题看起来很愚蠢,我很抱歉,因为我对lstm和tensor流程库很新。

感谢。

2 个答案:

答案 0 :(得分:8)

当它尝试使用语句state解包c, h=state时发生错误。根据您使用的tensorflow版本(您可以通过在python解释器中键入import tensorflow; tensorflow.__version__来检查版本信息),在r0.11之前的版本中,初始化时state_is_tuple参数的默认设置rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)设置为False。请参阅此处的documentation

BasicLSTMCell documentation in r0.10

由于tensorflow版本r0.11(或主版本),state_is_tuple的默认设置被设置为True。请在此处查看documentation

BasicLSTMCell documentation in r0.11

如果您安装了r0.11或tensorflow的主版本,请尝试将BasicLSTMCell初始化行更改为: lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False)。您遇到的错误应该消失。虽然,他们的页面确实表示state_is_tuple=False行为很快就会被弃用。

BasicLSTMCell state_is_tuple argument documentation

答案 1 :(得分:3)

我碰巧同时遇到了同样的问题。 我只是描述一下可能对你有帮助的情况

它表示如下

c1_ex, T1_ex = tf. ones(10,tf. int 32)
 raise Type Error ...

我发现' ='的左侧已预先设置了两个矢量名称

而另一方只返回一个向量

抱歉我的英语效率低下

你的问题实际上出现在第146行而不是第193行