鉴于CNN的输出形状为[batch_size, height, width, number_of_channels]
(假设格式为channels_last
),我有这种方式将CNN维度转换为RNN维度:
def collapse_to_rnn_dims(inputs):
batch_size, height, width, num_channels = inputs.get_shape().as_list()
if batch_size is None:
batch_size = -1
return tf.reshape(inputs, [batch_size, width, height * num_channels])
确实有效。但是,我想问一下这是否真的是重塑CNN输出的正确方法,以便将它们传递给LSTM层。
答案 0 :(得分:2)
我找到了一个答案here,它完全符合我为手写文本识别所做的工作,尽管这个假设number_of_time_steps
(宽度)是动态的而不是batch_size
shape = cnn_net.get_shape().as_list() # [batch, height, width, features]
transposed = tf.transpose(cnn_net, perm=[0, 2, 1, 3],
name='transposed') # [batch, width, height, features]
conv_reshaped = tf.reshape(transposed, [shape[0], -1, shape[1] * shape[3]],
name='reshaped') # [batch, width, height x features]