我有一个自定义模型来训练文本分类器。在我们的代码中,我们改写了tf.keras.models.Model
的所有重要方法,例如fit
,predict
等。其原因之一是我们希望在每个时期都适应批量大小。因此,在第一个时期,例如,批次大小为16,在下一次迭代中,批次大小将增加为20,依此类推。直到达到上限为止。
这是生成我们的tf.Dataset
的方法。我们将为每个时期调用此方法,以获取具有所需批处理大小的数据集。
def as_tf_dataset(
self, batch_size: int, batch_strategy: Text = SEQUENCE, shuffle: bool = False
) -> tf.data.Dataset:
"""Create tf dataset."""
shapes, types = self._get_shapes_types()
return tf.data.Dataset.from_generator(
lambda batch_size_: self._gen_batch(batch_size_, batch_strategy, shuffle),
output_types=types,
output_shapes=shapes,
args=([batch_size]),
)
使用Tensorflow 2.2。可以只覆盖train_step
和test_step
而不是fit
等。这将大大简化我们的代码。但是,我找不到一种方法来保持批次增加的数量。
有人对如何解决这个问题有想法吗?还是我们需要保留我们的自定义fit
方法才能实现这一目标?