如何将多个数据集合并为一个数据集?

时间:2019-03-14 04:09:24

标签: python tensorflow tfrecord tf.keras eager-execution

假设我有3个tfrecord文件,即neg.tfrecordpos1.tfrecordpos2.tfrecord

我用

dataset = tf.data.TFRecordDataset(tfrecord_file)

此代码创建3个数据集对象。

我的批量大小为400,包括200个负数据,100个pos1数据和100个pos2数据。如何获得所需的数据集?

我将在keras.fit()(急切执行)中使用此数据集对象。

我的tensorflow的版本是1.13.1。

之前,我尝试为每个数据集获取迭代器,然后在获取数据后手动进行连接,但是效率低下并且GPU利用率不高。

1 个答案:

答案 0 :(得分:1)

您可以使用interleave

filenames = [tfrecord_file1, tfrecord_file2]
dataset = (Dataset.from_tensor_slices(filenames).interleave(lambda x:TFRecordDataset(x)
dataset = dataset.map(parse_fn)
...

或者您甚至可以尝试并行交错。参见https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave