获取文件名的例子来自tf,parse_exampes

时间:2018-04-09 18:16:16

标签: tensorflow tensorflow-datasets

我正在tensorflow中编写一个数据输入管道,它使用一堆带有不同示例(类型)的tfrecord文件。

我使用的代码如下:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

但是我希望我的parse_function与file1.tfrecord不同,而不是file2.tfrecord。我如何实现这一目标。在parse_example中有什么知道某个特定示例来自哪个文件?

1 个答案:

答案 0 :(得分:1)

您可以使用Dataset.flat_map()转换为每条记录包含文件名,如下所示:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
filenames = tf.data.from_tensor_slices(filenames)

# `Dataset.flat_map()` creates a nested dataset from each element in `filenames`.
#
# For each file in filename, zip together the filename (repeated infinitely) with
# the records read from that file.
dataset = filenames.flat_map(
    lambda fn: tf.data.Dataset.zip((tf.data.Dataset.from_tensors(fn).repeat(None),
                                    tf.data.TFRecordDataset(fn))))

# The _parse_function can now be modified to take both the filename and the record.
dataset = dataset.map(lambda fn, record: _parse_function(fn, record))