如何将使用TimeseriesGenerator的代码迁移到Dataset API?

时间:2020-02-18 11:15:30

标签: tensorflow keras tensorflow2.0 tensorflow-datasets

在使用Tensorflow 1.x时,我们有以下内容:

X_train, Y_train = get_data('train')
gen = TimeseriesGenerator(X, Y, spec_w, batch_size=batch_size)
model.fit(gen, ...)

X的形状为(n_examples,n_bins)np.float32和 Y的形状为(n_examples,n_predictions)np.bool

我将示例转换为TFrecord。 每个示例中的特征为(X [n],Y [n]):

{
  'bins': tf.io.FixedLenFeature([n_bins], tf.float32),
  'predictions': tf.io.FixedLenFeature([n_predictions], tf.int64),
}

我正在尝试使用window / flat_map方法,但是我无法使结果数据集具有正确的形状,因为只有输入向量应具有形状(?,spec_w,n_bins),而预测向量应为形状(?,n_predictions)。我需要以某种方式将数据集分为输入和预测,对输入进行window和flat_map,然后压缩预测。

这是我当前的(错误的)代码:

    ds = tf.data.TFRecordDataset(kind + '.tfrecord')
    ds = ds.batch(batch_size)
    ds = ds.map(_parse_batch)
    ds = ds.window(spec_w, 1).flat_map(
        lambda *x: tf.data.Dataset.zip(tuple(col.batch(spec_w)
                                             for col in x)))
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

这个问题讨论了一个类似的问题: Sliding window of a batch in Tensorflow using Dataset API

0 个答案:

没有答案