如何将多个tfrecord数据集合并为一个数据集?

时间:2019-03-18 01:52:02

标签: python tensorflow deep-learning tfrecord

假设我有3个tfrecord文件,即neg.tfrecordpos1.tfrecordpos2.tfrecord

我的批量大小为500,包括300个负数据,100个pos1数据和100个pos2数据。如何获得所需的TFRecordDataset?

我将在keras.fit()(急切执行)中使用此TFRecordDataset对象。

我的tensorflow的版本是1.13.1。我在tf.data.Dataset中找到了API,例如interleaveconcatenatezip,但看来我无法解决问题。

之前,我尝试为每个数据集获取迭代器,然后在获取数据后手动进行连接,但是效率低下并且GPU利用率不高。

在此question中,我在下面使用interleave

tfrecord_files = ['neg.tfrecord', 'pos1.tfrecord', 'pos2.tfrecord']
dataset = tf.data.Dataset.from_tensor_slices(tfrecord_files)
def _parse(x):
    x = tf.data.TFRecordDataset(x)
    return x
dataset = dataset.interleave(_parse, cycle_length=4, block_length=1)
dataset = dataset.apply(tf.data.experimental.map_and_batch(_parse_image_function, 500))

我得到了这批:

neg pos1 pos2 neg pos1 pos2 ...............

但是我想要的是这个

neg neg neg pos1 pos2 neg neg neg pos1 pos2 .................

我该怎么办?

期待回答。

1 个答案:

答案 0 :(得分:1)

我使用字符串数据复制了您所说的内容:

import tensorflow as tf

def string_data(s):
    return tf.sparse.to_dense(tf.strings.split([s]), default_value='')[0]

data = [' '.join(['neg'] * 30), ' '.join(['pos1'] * 10), ' '.join(['pos2'] * 10)]
step_sizes = tf.constant([3, 1, 1], dtype=tf.int64)
ds = (tf.data.Dataset.from_tensor_slices((data, step_sizes))
      .interleave(lambda d, s: (tf.data.Dataset.from_tensor_slices(string_data(d))
                                .batch(s)),
                  cycle_length=len(data))
      .flat_map(tf.data.Dataset.from_tensor_slices))
iter = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter).decode(), end=', ')
        except tf.errors.OutOfRangeError: break
    print()

输出:

neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, 

在实际使用情况下,您可以将data替换为文件名列表,并将tf.data.Dataset.from_tensor_slices(string_data(d))替换为tf.data.TFRecordDataset(d),但否则应该可以类似地工作。

编辑:我只是意识到您实际上想要一批以这种方式排序的所有元素,而不是一次只包含一个元素,所以我想您将不得不在末尾添加另一个batch调用。 / p>