Tensorflow dynamic_rnn输入等级错误

时间:2017-10-07 22:12:51

标签: tensorflow recurrent-neural-network

所以,我试图在tensorflow中使用rnn来生成文本。但是,一旦我从static_rnn切换到dynamic_rnn,我就会收到此错误:

File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_shape.py", line 654, in with_rank_at_least
    raise ValueError("Shape %s must have rank at least %d" % (self, rank))
ValueError: Shape (100, 5) must have rank at least 3

这是生成错误的代码的一部分:

inputs_series = self.input_layer()
with tf.variable_scope(constants.HIDDEN):
    self.hidden_state_placeholder = tf.placeholder(
        dtype=tf.float32, 
        shape=[self.settings.train.batch_size, self.settings.rnn.hidden_size],
        name="hidden_state_placeholder")
    cell = tf.contrib.rnn.GRUCell(self.settings.rnn.hidden_size)
    states_series, self.current_state = tf.nn.dynamic_rnn(
        cell=cell, 
        inputs=inputs_series,
        initial_state=self.hidden_state_placeholder)

inputs_series的形状为:(10,5,100),对应于(截断文本长度,批量大小,类数)

hidden_state_placeholder的形状为(5,100)(批量大小,隐藏状态大小),但即使我没有提供初始状态,错误仍然存​​在。

张量流版本是1.3,如果它有帮助。

任何见解都将不胜感激!

1 个答案:

答案 0 :(得分:0)

默认情况下,time_major == False中有tf.nn.dynamic_rnn,但inputs_seriestime_major == True。所以也许可以将最后一个语句改为

states_series, self.current_state = tf.nn.dynamic_rnn(
    cell=cell, 
    inputs=inputs_series,
    initial_state=self.hidden_state_placeholder,
    time_major=True)