Passing CNN output to an LSTM in TensorFlow?

Time: 2018-03-07 08:03:52

Tags: tensorflow lstm convolution

Given that the CNN output has shape [batch_size, height, width, number_of_channels] (assuming the channels_last format), I convert the CNN dimensions to RNN dimensions like this:

def collapse_to_rnn_dims(inputs):
    batch_size, height, width, num_channels = inputs.get_shape().as_list()
    if batch_size is None:
        batch_size = -1
    return tf.reshape(inputs, [batch_size, width, height * num_channels])

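It does run. However, I would like to ask whether this is really the correct way to reshape CNN outputs so that they can be passed to an LSTM layer.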

1 Answer:

Answer 0: (score: 2)

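I found an answer here that does exactly what I need for handwritten text recognition, although it assumes that number_of_time_steps (the width) is the dynamic dimension rather than batch_size. Note that it transposes width ahead of height before reshaping, so each time step holds the features of a single image column: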

shape = cnn_net.get_shape().as_list()  # [batch, height, width, features]
transposed = tf.transpose(cnn_net, perm=[0, 2, 1, 3],
                          name='transposed')  # [batch, width, height, features]
conv_reshaped = tf.reshape(transposed, [shape[0], -1, shape[1] * shape[3]],
                           name='reshaped')  # [batch, width, height x features]
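
To see why the transpose matters, here is a minimal sketch using NumPy rather than TensorFlow (an assumption made only to keep the example small; tf.reshape and tf.transpose use the same row-major ordering). The helper names collapse_reshape_only and collapse_transpose_then_reshape are hypothetical and exist only for this illustration. Both versions return the same shape, which is why the reshape-only version appears to work, but only the transposed version puts one image column into each time step.

```python
import numpy as np


def collapse_reshape_only(x):
    """Reshape [batch, height, width, channels] directly, as in the question."""
    b, h, w, c = x.shape
    return x.reshape(b, w, h * c)


def collapse_transpose_then_reshape(x):
    """Move width ahead of height first, as in the answer, then merge height and channels."""
    b, h, w, c = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, w, h * c)


# Toy feature map: batch=1, height=2, width=3, channels=2.
x = np.arange(1 * 2 * 3 * 2).reshape(1, 2, 3, 2)

naive = collapse_reshape_only(x)
correct = collapse_transpose_then_reshape(x)

# Time step 0 should contain exactly the features of image column 0.
expected_step0 = x[0, :, 0, :].reshape(-1)

print("same shape:", naive.shape == correct.shape)
print("transpose + reshape matches column 0:", np.array_equal(correct[0, 0], expected_step0))
print("reshape only matches column 0:", np.array_equal(naive[0, 0], expected_step0))
```

With the reshape-only version, the first time step mixes values from neighbouring columns of the top row instead of collecting the full height of one column, so the sequence fed to the LSTM no longer corresponds to a left-to-right scan of the image.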