我正在tensorflow中编写一个数据输入管道,它使用一堆带有不同示例(类型)的tfrecord文件。
我使用的代码如下:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
但是我希望我的parse_function与file1.tfrecord不同,而不是file2.tfrecord。我如何实现这一目标。在parse_example中有什么知道某个特定示例来自哪个文件?
答案 0 :(得分:1)
您可以使用Dataset.flat_map()
转换为每条记录包含文件名,如下所示:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
filenames = tf.data.from_tensor_slices(filenames)
# `Dataset.flat_map()` creates a nested dataset from each element in `filenames`.
#
# For each file in filename, zip together the filename (repeated infinitely) with
# the records read from that file.
dataset = filenames.flat_map(
lambda fn: tf.data.Dataset.zip((tf.data.Dataset.from_tensors(fn).repeat(None),
tf.data.TFRecordDataset(fn))))
# The _parse_function can now be modified to take both the filename and the record.
dataset = dataset.map(lambda fn, record: _parse_function(fn, record))