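Error when running an LSTM with BasicLSTMCell in v0.12.0

Time: 2017-01-05 22:19:54

Tags: python tensorflow lstm

I have v0.12.0 installed on my system and I am running into problems with a simple LSTM. While trying to get it working, I even cut the code down to the example at https://www.tensorflow.org/tutorials/recurrent/, but I still hit the same problem. I am attaching the relevant code snippet and the error log. The "lstm" function takes its input from a convolution function that produces a latent representation (size: 1024) for a sequence of 40 frames.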

frames_batch_size = 40
batch_size = 20

def lstm(x, state_size=1024, initial_state=None, reuse=False):
    with tf.variable_scope("lstm") as lstm_scope:
        if reuse:
            tf.get_variable_scope().reuse_variables()

        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(state_size)
        state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")

        print(x.name, "-", x.get_shape())
        print(state.name, "-", state.get_shape())

        for i in range(frames_batch_size):
            output, state = lstm_cell(x[:, i], state)

        print(output.name, output.get_shape())
    return output

According to the printed output, the tensor shapes are:

generator/convolution/conv_output:0 - (20, 40, 1024)
generator/lstm/lstm_state:0 - (20, 1024)

A snippet of the error log:

<ipython-input-13-35b652ff4acc> in lstm(x, state_size, initial_state, reuse)
     10 
     11         for i in range(frames_batch_size):
---> 12             output, state = lstm_cell(x[:, i], state)
     13 
     14         print(output.name, output.get_shape())

/usr/local/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py in __call__(self, inputs, state, scope)
    306       # Parameters of gates are concatenated into one multiply for efficiency.
    307       if self._state_is_tuple:
--> 308         c, h = state
    309       else:
    310         c, h = array_ops.split(1, 2, state)

/usr/local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
    508       TypeError: when invoked.
    509     """
--> 510     raise TypeError("'Tensor' object is not iterable.")
    511 
    512   def __bool__(self):

TypeError: 'Tensor' object is not iterable.

1 Answer:

Answer 0: (score: 0)

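Because state_is_tuple defaults to True, the cell expects its state to be a (c, h) tuple rather than a single tensor, so you need to initialise the state variable with lstm_cell.zero_state(batch_size, tf.float32).

Replace

state = initial_state if initial_state else tf.zeros([batch_size, state_size], name="lstm_state")

with

state = initial_state if initial_state else lstm_cell.zero_state(batch_size, tf.float32)

Also, make sure that whatever you pass as the initial_state argument is an LSTMStateTuple object.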
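
To see why the traceback ends in `c, h = state`, here is a minimal sketch in plain Python with no TensorFlow involved. `FakeTensor`, `unpack_state`, `try_unpack` and this `zero_state` helper are hypothetical stand-ins written for illustration only. The one thing taken from TensorFlow is that LSTMStateTuple is a namedtuple with fields c and h.

```python
from collections import namedtuple

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple, which is a namedtuple of (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class FakeTensor(object):
    """Hypothetical stand-in for a tf.Tensor: it refuses to be iterated."""

    def __init__(self, shape):
        self.shape = shape

    def __iter__(self):
        raise TypeError("'Tensor' object is not iterable.")


def unpack_state(state):
    """Mimics the `c, h = state` line shown in the traceback."""
    c, h = state
    return c, h


def zero_state(batch_size, state_size):
    """Hypothetical analogue of lstm_cell.zero_state: one tensor each for c and h."""
    shape = (batch_size, state_size)
    return LSTMStateTuple(c=FakeTensor(shape), h=FakeTensor(shape))


def try_unpack(state):
    """Returns 'ok' if the state unpacks into (c, h), else the error message."""
    try:
        unpack_state(state)
        return "ok"
    except TypeError as err:
        return str(err)


# A single tensor, like tf.zeros([batch_size, state_size]), cannot be unpacked.
single_result = try_unpack(FakeTensor((20, 1024)))
# A (c, h) pair, like the one zero_state builds, unpacks cleanly.
tuple_result = try_unpack(zero_state(20, 1024))
print(single_result)
print(tuple_result)
```

The first call reproduces the same TypeError message as in the question, and the second succeeds because the state is already a two-element tuple. The real cell behaves the same way only in this one respect; everything else about it is left out of the sketch.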