假设我有3个tfrecord文件,即neg.tfrecord
,pos1.tfrecord
,pos2.tfrecord
。
我用
dataset = tf.data.TFRecordDataset(tfrecord_file)
此代码创建3个数据集对象。
我的批量大小为400,包括200个负数据,100个pos1数据和100个pos2数据。如何获得所需的数据集?
我将在keras.fit()(急切执行)中使用此数据集对象。
我的tensorflow的版本是1.13.1。
之前,我尝试为每个数据集获取迭代器,然后在获取数据后手动进行连接,但是效率低下并且GPU利用率不高。
答案 0 :(得分:1)
您可以使用interleave
filenames = [tfrecord_file1, tfrecord_file2]
dataset = (Dataset.from_tensor_slices(filenames).interleave(lambda x:TFRecordDataset(x)
dataset = dataset.map(parse_fn)
...
或者您甚至可以尝试并行交错。参见https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave