我大约有310万条记录,这些记录分为两个TFRecords文件。一个包含正类(〜217K),另一个包含负类(〜2.9MM)。我正在尝试使用Dataset API来交错记录,以使每个批次具有50/50的比例。为了用尽所有数据,我想重复一些肯定的例子,以便使用所有否定的例子。
现在最终发生的是它甚至开始,但是当肯定记录用完时,批处理中只会出现否定记录。
我相信如果文件名是.repeat()
,可以在下面的代码中添加train_pos.tfrecords
来解决此问题,但是,我无法弄清楚如何修改_get_files()
函数所以。我认为这可能是我所缺少的简单答案?
files = tf.data.Dataset.list_files("train_*.tfrecords")
def _get_files(x):
return tf.data.TFRecordDataset(x).shuffle(buffer_size=10000)
dataset = files.apply(tf.contrib.data.parallel_interleave(
lambda x: _get_files(x), cycle_length=2))\
.batch(self.batch_size)\
.map(_parse_line, num_parallel_calls=6)\
.repeat(1)\
.prefetch(2)
答案 0 :(得分:0)
您可以通过两次调用相关的TF记录tf.data.Dataset
来创建两个数据集:
files1 = tf.data.Dataset.list_files(...)
files2 = tf.data.Dataset.list_files(...)
,并使用repeat(-1)
使两个数据集取之不尽。
然后,您可以使用两个批处理的数据集的输出并将它们连接起来,以实现均衡的批处理。