如何在没有固定batch_size的情况下设置Tensorflow dynamic_rnn,zero_state?

时间:2017-07-12 04:33:34

标签: python tensorflow

根据Tensorflow的官方网站(https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state) zero_state必须指定batch_size。 我找到的很多例子都使用这段代码:

    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)

对于培训步骤,可以修复批量大小。但是,在预测时,测试集的形状可能与训练集的批量大小不同。 例如,我的一批训练数据具有形状[100,255,128]。批量大小为100,包含255个步骤和128个输入。而测试集是[2000,255,128]。我无法预测,因为在dynamic_rnn(initial_state)中,它已经设置了一个固定的batch_size = 100。 我该如何解决这个问题?

感谢。

3 个答案:

答案 0 :(得分:11)

您可以将batch_size指定为占位符,而不是常量。只需确保在feed_dict中提供相关的数字,这对于培训和测试来说会有所不同

重要的是,请将[]指定为占位符的维度,因为如果您指定None,则可能会出现错误,这在其他地方也是如此。所以这样的事情应该有效:

batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size:100})
out = sess.run(outputs, feed_dict={batch_size:10})

显然,请确保批处理参数与输入的形状相匹配,如果dynamic_rnn设置为[batch_size, seq_len, features][seq_len, batch_size, features]将解释为time_majorTrue }

答案 1 :(得分:0)

有一个相当简单的实现。只需删除initial_state即可!这是因为初始化过程可能会预先分配一个批处理大小的内存。

答案 2 :(得分:0)

正如@陈狗蛋回答的那样,无需在initial_state中设置tf.compat.v1.nn.dynamic_rnn,因为它是可选的。您可以简单地这样做

outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, 
                                         X_inputs, 
                                         time_major=False, 
                                         dtype=tf.float32)

不要忘记设置dtype,这里我设置了tf.float32,您可以根据需要设置dtype

正如tf.compat.v1.nn.rnn_cell.LSTMCell的{​​{3}}所说:

batch_size:表示批次大小的整数,浮点数或单位张量

batch_size必须是一个明确的值。因此,对batch_size参数使用占位符是一种解决方法,但不建议使用。我建议您不要使用它,因为在将来的版本中它可能是无效的方法。