Tensorflow 2.2:数据集随着批量大小的增加

时间:2020-06-30 09:37:52

标签: tensorflow keras

我有一个自定义模型来训练文本分类器。在我们的代码中,我们改写了tf.keras.models.Model的所有重要方法,例如fitpredict等。其原因之一是我们希望在每个时期都适应批量大小。因此,在第一个时期,例如,批次大小为16,在下一次迭代中,批次大小将增加为20,依此类推。直到达到上限为止。

这是生成我们的tf.Dataset的方法。我们将为每个时期调用此方法,以获取具有所需批处理大小的数据集。

def as_tf_dataset(
    self, batch_size: int, batch_strategy: Text = SEQUENCE, shuffle: bool = False
) -> tf.data.Dataset:
    """Create tf dataset."""

    shapes, types = self._get_shapes_types()

    return tf.data.Dataset.from_generator(
        lambda batch_size_: self._gen_batch(batch_size_, batch_strategy, shuffle),
        output_types=types,
        output_shapes=shapes,
        args=([batch_size]),
    )

使用Tensorflow 2.2。可以只覆盖train_steptest_step而不是fit等。这将大大简化我们的代码。但是,我找不到一种方法来保持批次增加的数量。

有人对如何解决这个问题有想法吗?还是我们需要保留我们的自定义fit方法才能实现这一目标?

0 个答案:

没有答案