如何在简单的BasicRNN模型中处理可变长度输出序列?

时间:2017-10-02 13:36:35

标签: python tensorflow lstm rnn

我正在尝试在tensorflow中运行一个简单的 BasicRNN模型,具有不同的输入/输出大小。输入的形状应为 [batch_size,50,2] ,输出的形状应为 [batch_size,75,2]

在完成下面相同形状的示例(in_length = out_length)之后,我尝试将sequence_length参数添加到dynamic_rnn以预测更大的输出序列,但它似乎不起作用,因为我总是收到值错误喂食模型时

in_length = 50
out_length = 75
num_inputs = 2
num_outputs = 2
hidden = 64
batch_size = 1024

x = tf.placeholder(tf.float32, [None, out_length, num_inputs], name="Inputs")
y = tf.placeholder(tf.float32, [None, out_length, num_outputs], name="Labels")

basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=hidden, activation=tf.nn.relu)

rnn_output, _ = tf.nn.dynamic_rnn(basic_cell, 
                                  x, 
                                  sequence_length=[in_length]*batch_size,
                                  dtype=tf.float32)

stacked_rnn_output = tf.reshape(rnn_output, [-1, hidden])
stacked_outputs = tf.layers.dense(stacked_rnn_output, num_outputs)
outputs = tf.reshape(stacked_outputs, [-1, out_length, num_outputs])

...

batch = np.random.choice(N, batch_size, replace=False)
sess.run([training_op], feed_dict={x:x_batches[batch], y:y_batches[batch]})

值错误:

Traceback (most recent call last):
File "/home/timeseries/main.py", line 142, in <module>
    sess.run([training_op], feed_dict={x:x_batches[batch], y:y_batches[batch]})
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 961, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))

ValueError: Cannot feed value of shape (1024, 50, 2) for Tensor u'Inputs:0', which has shape '(?, 100, 2)'

有人对我有暗示吗?在理解不同序列长度时是否存在一般错误?我猜,这些形状对于dynamic_rnn来说应该不是问题......

修改

我添加了追溯,虽然我不认为,这可能会有所帮助。这绝对是第二个维度的一个问题,我已经解决了,但它对我的模型来说不是一个正确的解决方案。

我通过用零填充输入来更新我的模型,所以我对输入和输出有相同的张量。该模型不会再出现值错误,但我希望我可以为网络提供50个实数值+25个零= 75并且将接收75个预测值,因为我正在使用序列长度参数并且我正在训练网络有75个真实标签。 但输出是50个实数值和25个零......

这应该是一个多对多/一对多的问题,我认为这对于一个人来说没问题......

0 个答案:

没有答案