TensorFlow LSTM实现提供有状态和无状态变体。默认情况下,使用无状态变量,该变量在一批训练后会自动重置LSTM状态。给定LSTM网络的思想,输入数据可能是应该学习的某种序列或序列。据我了解,状态通常应在一个(多个)完整序列中完全重置。 LSTM的输入大小为
batch_input_shape=(batch_size, num_timesteps, num_features)
,因此(据我了解),一个系列的长度应为batch_size * num_timesteps
。另外,当整个系列太大而无法容纳一批时,可以训练完一个系列的所有批次后,使用有状态变量手动重置。
现在,我有很多系列,每个系列都在各自的.tfrecord
文件中,并且我目前正在训练无状态变体。但是,我不确定如何在TensorFlow中处理该系列。使用
ds = tf.data.TFRecordDataset(files)
和一些预处理,我可以从这些文件创建一个数据集,然后使用
进行训练model.fit(ds, epochs=num_epochs, shuffle=False)
但是,据我了解,数据集不再区分单个文件,而是将所有文件串联在一起。这意味着即使对于不同长度的序列,训练也始终使用相同的批处理大小(即不使用填充),并可能混淆不同的序列。
因此,我的想法是在使用有状态LSTM的同时手动迭代文件,以较小的批数进行训练,并在每个序列后手动重置状态。我如何迭代TFRecordDataset
的各个文件。这可能吗?有什么更好的方法吗?