如何在seq2seq任务中将卷积层与lstm层连接到?

时间:2019-08-07 19:39:28

标签: tensorflow conv-neural-network lstm seq2seq

seq2seq任务是从视频数据中识别句子(也称为纯视觉语音识别/唇读)。

该模型由卷积层和lstm层组成。但是,卷积层的输出为 [batch_size, height, width, channel_size] 的形式;而lstm层的输入必须为 [batch_size, n_steps, dimension] 的形状。

工作流程如下:

  • 首先,数据组织为[batch_size,n_steps,高度,宽度,channel_size]。
  • 然后我将其重塑为[batch_size*n_steps, height, width, channel_size]并将其输入转换层。
  • 转换层的输出为[batch_size*n_steps, height', width', channel_size']
  • 我当然可以将其重塑为[batch_size, n_steps, height', width', channel_size'],但是如何将其输入lstm层,该层需要数据为[batch_size, n_steps, dimension]

我不知道仅将轴[height', width', channel_size']整形为[dimension]的一个轴是否适合此仅视觉语音识别任务。

提示:

1 个答案:

答案 0 :(得分:1)

RNN希望输入将是连续的。因此,输入的形状为[time, feature_size],或者如果您正在处理批处理[batch_size, time, feature_size]

在您的情况下,输入的形状为[batch_size, number_of_frames, height, width, num_channels]。然后,使用卷积层来了解每个视频帧中像素之间的空间依赖性。因此,对于每个视频帧,卷积层将为您提供形状为[activation_map_width, activation_map_height, number_of_filters]的张量。然后,由于您要学习框架的上下文相关表示,因此可以安全地重塑一维序列中每个框架所学到的一切。

最后,您将提供RNN:[b_size, num_frames, am_width * am_height * num_filters]

对于实现,如果我们假设您有2个视频,并且每个视频有5帧,其中每帧的宽度和高度分别为10和3个频道,那么您应该这样做:

# Batch of 2 videos with 7 frames of size [10, 10, 3]
video = np.random.rand(2, 7, 10, 10, 3).astype(np.float32)
# Flattening all the frames
video_flat = tf.reshape(video, [14, 10, 10, 3])
# Convolving each frame
video_convolved = tf.layers.conv2d(video_flat, 5, [3,3])
# Reshaping the frames back into the corresponding batches
video_batch = tf.reshape(video_convolved, [2, 7, video_convolved.shape[1], video_convolved.shape[2], 5])
# Combining all learned for each frame in 1D
video_flat_frame = tf.reshape(video_batch, [2, 7, video_batch.shape[2] * video_batch.shape[3] * 5])
# Passing the information for each frame through an RNN
outputs, _ = tf.nn.dynamic_rnn(tf.nn.rnn_cell.LSTMCell(9), video_flat_frame, dtype=tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Output where we have a context-dependent representation for each video frame
    print(sess.run(outputs).shape)

请注意,为简单起见,我已经在代码中对一些变量进行了硬编码。

希望对您有所帮助!