dynamic_rnn的输出形状,time_major = True

时间:2017-07-06 17:58:09

标签: tensorflow rnn

我使用TensorFlow来实现RNN。我创建了这样的循环单位:

gru_cell = tf.contrib.rnn.GRUCell(16)
zero_state = gru_cell.zero_state(1, tf.float32)
initial_state = tf.placeholder(tf.float32, zero_state.get_shape())
out_tensor, final_state = tf.nn.dynamic_rnn(
    gru_cell,
    parent_tensor,
    initial_state=initial_state,
    time_major=False)
print(out_tensor.get_shape())

它将输出形状报告为(1, ?, 16),正如我所料。第二个维度为?,因为max_time未知。

现在我将其切换为time_major=True。根据文档,我希望只交换前两个轴,因此输出形状应为(?, 1, 16)。但它不是。相反,它是(1, 1, 16)。发生了什么事? max_time仍然未知,为什么它会硬编码为1?

1 个答案:

答案 0 :(得分:0)

来自Tesnorflow文档,https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn

enter image description here

根据上面的说法,现在你的max_time = 1,batch_size = ?, input_size = 16

现在根据文档, enter image description here

你的输出应该是(1,?,16)的形状。