从每个包含10个样本的文件中加载我的输入,并以每批次4个样本进行批处理,我得到了大小不等的4、4、2、4、4、2等批次,而不是将连续样本合并整理数据集后,文件符合我的预期。
我正在使用TensorFlow 1.8.0。为了将文件中的数据放入我的数据集对象中,我遵循了this answer。我的输入管道如下所示:
# Initialize dataset on files
dataset = tf.data.Dataset.list_files(input_files_list)
# Pre-process data in parallel
def preprocess_fn(input_file):
# lots of logic here...
return input1, input2, input3
map_fn = lambda input_file: tf.py_func(
preprocess_fn, [input_file], [tf.float32, tf.float32, tf.float32])
dataset = dataset.map(map_func=map_fn, num_parallel_calls=4)
# Flatten from files to samples
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.batch(batch_size=4)
dataset = dataset.prefetch(buffer_size=8)
但是我看到的是样本实际上并未在输入文件之间连接,因此批次大小是不均匀的。我认为这是因为flat_map()
将每个元素(来自文件的所有输入样本)映射到数据集-因此在flat_map()
之后,我的数据集实际上是数据集的一个数据集,并且每个嵌套的数据集都是分别进行批处理的。
但这不是我想要的。如何连接嵌套的数据集,或以其他方式展平数据集,以便可以将来自不同文件的样本一起批处理?
答案 0 :(得分:0)
我在TF 2.0上也遇到类似的问题,并且使用了unbatch
函数,就像这样:
dataset = dataset.flat_map(lambda f: parse_function(f)).apply(tf.data.experimental.unbatch())
我相信可以使用TF 1.8 tf.contrib.data.unbatch。