如何将`tf.nn.dynamic_rnn`与非rnn组件一起使用

时间:2018-06-11 16:09:57

标签: python tensorflow recurrent-neural-network

我有一个架构在进入RNN之前使用编码器。编码器输入形状为[batch, height, width, channels],RNN输入为形状[batch, time, height, width, channels]。我想将编码器的输出直接提供给RNN,但这会造成内存问题。我必须一次将batch*time ~= 3*100(通过重新整形)图像输入编码器。我知道tf.nn.dynamic_rnn可以利用swap_memory,我也想在编码器中利用它。以下是一些精简代码:

#image inputs [batch, time, height, width, channels]
inputs = tf.placeholder(tf.float32, [batch, time, in_sh[0], in_sh[1], in_sh[2]])

#This is where the trouble starts
#merge batch and time
inputs = tf.reshape(inputs, [batch*time, in_sh[0], in_sh[1], in_sh[2]])
#build the encoder (and get shape of output)
enc, enc_sh = build_encoder(inputs)
#change back to time format
enc = tf.reshape(enc, [batch, time, enc_sh[0], enc_sh[1], enc_sh[2]])

#build rnn and get initial state (zero_state)
rnn, initial_state = build_rnn()
#use dynamic unrolling
rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
        rnn, enc,
        initial_state=initial_state,
        swap_memory=True,
        time_major=False)

我正在使用的当前方法是先验地在所有图像上运行编码器(并保存到光盘),但我想执行数据集扩充(对图像),这在提取特征后是不可能的。< / p>

1 个答案:

答案 0 :(得分:0)

对于遇到此问题的其他任何人。我做了一个从/foo/foosb1other/foo/bar/foosb2bar派生的包装程序,该包装程序完成了我所需要的。 RNNCell是一个使用输入构建子图并返回输出张量的函数。不幸的是,必须知道输出形状(至少我无法使其正常工作)。

model_fn