Tensorflow RNN示例仅限于固定批量大小?

时间:2017-04-03 14:40:34

标签: machine-learning tensorflow deep-learning recurrent-neural-network

在Tensorflow上查看RNN example时遇到初始状态构造问题。在图形的构建时,我们将图形限制为仅处理一个批量大小的输入。这对我来说是一个问题,因为我希望能够提供一个示例并获得该单个示例的预测。

限制此问题的代码部分是:

initial_state = state = tf.zeros([batch_size, lstm.state_size])

所以我的问题是如何扩展示例以便我可以使用变量批量大小,以便我可以使用相同的模型进行批量大小的训练,然后使用单个示例进行预测?

2 个答案:

答案 0 :(得分:2)

这就是我这样做的方式。您可以将df.Value.idxmax() 作为变量传递:

batch_size

其中batch_size = tf.placeholder(tf.int32) init_state = cell.zero_state(batch_size, tf.float32) 是RNN单元格之一(cellBasicLSTMCellBasicGRUCell等。但是,如果您保留多个批次的状态,那么由于其大小必须保持不变,这将不起作用。

答案 1 :(得分:0)

Tensorflow文本生成教程介绍了如何执行此操作(现在为TF 2.0)。看来batch_size已成为构建模型的一部分,因此您必须使用新的批次大小从保存的重量中重建/重新加载:

https://www.tensorflow.org/tutorials/text/text_generation#restore_the_latest_checkpoint

  

为简化此预测步骤,请使用1的批量大小。

     

由于RNN状态从一个时间步到另一个时间步的传递方式,   该模型一旦建立,就只接受固定的批量大小。

     

要使用其他batch_size运行模型,我们需要重建   建模并从检查点恢复权重。

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
model.summary()

我不确定为什么要这样做,但我一直认为这是因为为循环层进行批处理需要管理多个并行的隐藏状态管道,因此它会对其进行预分配。