拒绝行数小于设置大小的批次

时间:2019-05-03 17:03:01

标签: python tensorflow machine-learning artificial-intelligence

我有一个使用tf.data.experimental.group_by_window的tensorflow输入函数来基于特定的键值创建批处理。

模型(自定义估算器)有点特殊,因为它要求批处理中的行数必须高于设置的最小值,即大于250。

在大多数情况下,这很好,因为给定键的行数非常大,但是偶尔会失败,并且返回的批处理少于最小行数。

在输入函数或模型中是否可以拒绝行数少于要求的批次?如果可以的话,我该怎么办?

group_key = 'key'
    dataset = dataset.apply(
        tf.data.experimental.group_by_window(
        key_func=lambda x: tf.to_int64(x[group_key]),
        reduce_func=lambda _, x: x.batch(batch_size),
        window_size=batch_size))

其中batch_size作为超参数提供,并且大于批处理的最小大小。

0 个答案:

没有答案