TensorFlow - Convolution + LSTM

Posted: 2017-10-08 21:04:52

Tags: tensorflow

I am trying to implement the architecture of Figure 1 from this paper: https://arxiv.org/pdf/1411.4389.pdf

I have followed TensorFlow's LSTM tutorial, but it does not help me, because the input data is not convolved in time. I need to run a CNN over each frame in the sequence and then feed the results into an LSTM. Does anyone know of any example code for this?

1 Answer:

Answer 0 (score: 1)

I have implemented such an architecture for multi-channel time-series data by unrolling the convolutional network in time and feeding the stacked output tensors into an LSTM network. The LSTM network is built from the standard tf.contrib.rnn.LSTMBlockCell, tf.contrib.rnn.MultiRNNCell, and tf.nn.dynamic_rnn.

A better way to unroll in time would be to use tf.while_loop to create the convolutional part of the network.

The following code illustrates the idea, but it has not been tested.

conv_outputs = [None] * len(frames)
with tf.variable_scope("ConvNet"):
    for idx, frame in enumerate(frames):
        conv_in = frame  # every frame runs through the same shared conv stack
        # assuming the shared weight and bias variables were created beforehand
        for i, (weights, bias) in enumerate(conv_maps):
            with tf.name_scope("Conv{}".format(i)):
                conv_out = tf.add(
                    tf.nn.conv2d(
                        conv_in,  # [batch, in_height, in_width, in_channels]
                        weights,  # [filter_height, filter_width, in_channels, out_channels]
                        strides=[1, 1, 1, 1],
                        padding="VALID", name="Conv{}".format(i)),
                    bias, name="Add{}".format(i))
                conv_out = tf.nn.relu(conv_out, name="Relu{}".format(i))
            with tf.name_scope("Pool{}".format(i)):
                pool_out = tf.nn.max_pool(
                    conv_out,  # [batch, in_height, in_width, in_channels]
                    ksize=[1, pool_size, 1, 1],
                    strides=[1, pool_size, 1, 1],
                    padding="VALID", name="Pool{}".format(i))
            conv_in = pool_out  # chain the conv/pool layers
        conv_outputs[idx] = pool_out  # keep the last layer's output for this frame
    stacked = tf.stack(conv_outputs, 1)
    # [batch_size, num_frames, last_layer_feature_maps]
    reshaped = tf.reshape(stacked, [-1, len(conv_outputs), conv_outputs[-1].shape[-1]])
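Shape-wise, the stack/reshape step collects one feature vector per frame and arranges them into the [batch, time, features] layout that tf.nn.dynamic_rnn expects. A minimal NumPy sketch of just that step (the batch size, frame count, and feature size are made up):

```python
import numpy as np

batch, num_frames, features = 4, 10, 32

# one flattened conv output per frame, each of shape [batch, features]
conv_outputs = [np.random.rand(batch, features) for _ in range(num_frames)]

# stack along a new time axis -> [batch, num_frames, features],
# the [batch, time, features] layout expected by tf.nn.dynamic_rnn
stacked = np.stack(conv_outputs, axis=1)
print(stacked.shape)  # (4, 10, 32)
```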

num_lstm_cells = [8, 8]

cell_series = [tf.contrib.rnn.LSTMBlockCell(n) for n in num_lstm_cells]
layers = tf.contrib.rnn.MultiRNNCell(cell_series, state_is_tuple=True)

# prediction of the LSTM network for input batch_x
net_out, state = tf.nn.dynamic_rnn(layers, reshaped, dtype=tf.float32)
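An alternative to unrolling the conv net frame by frame is to fold the time axis into the batch axis, run the CNN once over all frames, and split the axes apart again before the LSTM (the pattern behind Keras's TimeDistributed wrapper). A NumPy sketch of just the reshapes, with made-up sizes and a dummy mean-pooling step standing in for the shared CNN:

```python
import numpy as np

batch, time, h, w, c = 2, 5, 8, 8, 3
frames = np.random.rand(batch, time, h, w, c)

# fold time into the batch axis so one CNN pass sees every frame
merged = frames.reshape(batch * time, h, w, c)   # (10, 8, 8, 3)

# stand-in for the shared CNN: reduce each frame to a feature vector
features = merged.mean(axis=(1, 2))              # (10, 3)

# restore the time axis -> [batch, time, features] for the LSTM
sequence = features.reshape(batch, time, -1)     # (2, 5, 3)
print(sequence.shape)
```

Because the CNN weights are shared across frames anyway, both approaches compute the same thing; merging the axes just avoids building one graph copy per frame.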