如何使用TFRecordDataset从不同长度的序列中训练TensorFlow LSTM

时间:2019-09-11 09:11:01

标签: tensorflow keras lstm tensorflow-datasets

TensorFlow LSTM实现提供有状态和无状态变体。默认情况下,使用无状态变量,该变量在一批训练后会自动重置LSTM状态。给定LSTM网络的思想,输入数据可能是应该学习的某种序列或序列。据我了解,状态通常应在一个(多个)完整序列中完全重置。 LSTM的输入大小为

batch_input_shape=(batch_size, num_timesteps, num_features)

,因此(据我了解),一个系列的长度应为batch_size * num_timesteps。另外,当整个系列太大而无法容纳一批时,可以训练完一个系列的所有批次后,使用有状态变量手动重置。

现在,我有很多系列,每个系列都在各自的.tfrecord文件中,并且我目前正在训练无状态变体。但是,我不确定如何在TensorFlow中处理该系列。使用

ds = tf.data.TFRecordDataset(files)

和一些预处理,我可以从这些文件创建一个数据集,然后使用

进行训练
model.fit(ds, epochs=num_epochs, shuffle=False)

但是,据我了解,数据集不再区分单个文件,而是将所有文件串联在一起。这意味着即使对于不同长度的序列,训练也始终使用相同的批处理大小(即不使用填充),并可能混淆不同的序列。

因此,我的想法是在使用有状态LSTM的同时手动迭代文件,以较小的批数进行训练,并在每个序列后手动重置状态。我如何迭代TFRecordDataset的各个文件。这可能吗?有什么更好的方法吗?

0 个答案:

没有答案