TFRecords和Dataset API甚至可以用于迷你批次

时间:2018-12-22 23:37:41

标签: python tensorflow tensorflow-datasets

我大约有310万条记录,这些记录分为两个TFRecords文件。一个包含正类(〜217K),另一个包含负类(〜2.9MM)。我正在尝试使用Dataset API来交错记录,以使每个批次具有50/50的比例。为了用尽所有数据,我想重复一些肯定的例子,以便使用所有否定的例子。

现在最终发生的是它甚至开始,但是当肯定记录用完时,批处理中只会出现否定记录。

我相信如果文件名是.repeat(),可以在下面的代码中添加train_pos.tfrecords来解决此问题,但是,我无法弄清楚如何修改_get_files()函数所以。我认为这可能是我所缺少的简单答案?

files = tf.data.Dataset.list_files("train_*.tfrecords")       
def _get_files(x):
    return tf.data.TFRecordDataset(x).shuffle(buffer_size=10000)

dataset = files.apply(tf.contrib.data.parallel_interleave(
    lambda x: _get_files(x), cycle_length=2))\
    .batch(self.batch_size)\
    .map(_parse_line, num_parallel_calls=6)\
    .repeat(1)\
    .prefetch(2)

1 个答案:

答案 0 :(得分:0)

您可以通过两次调用相关的TF记录tf.data.Dataset来创建两个数据集:

files1 = tf.data.Dataset.list_files(...)
files2 = tf.data.Dataset.list_files(...)

,并使用repeat(-1)使两个数据集取之不尽。 然后,您可以使用两个批处理的数据集的输出并将它们连接起来,以实现均衡的批处理。