使用旧的输入管道API,我可以这样做:
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
然后将文件名传递给其他队列,例如:
reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, n)
如何使用数据集-API实现类似的行为?
tf.data.TFRecordDataset()
期望文件名的张量按固定顺序排列。
答案 0 :(得分:3)
开始按顺序阅读它们,shuffle之后:
BUFFER_SIZE = 1000 # arbitrary number
# define filenames somewhere, e.g. via glob
dataset = tf.data.TFRecordDataset(filenames).shuffle(BUFFER_SIZE)
this question的输入管道让我了解了如何使用数据集API实现文件名改组:
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(BUFFER_SIZE) # doesn't need to be big
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(decode_example, num_parallel_calls=5) # add your decoding logic here
# further processing of the dataset
这会将一个文件的所有数据放在下一个文件之前,依此类推。文件被洗牌,但其中的数据将以相同的顺序生成。
您也可以将dataset.flat_map
替换为interleave
以同时处理多个文件并从每个文件中返回样本:
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
注意: interleave
实际上并不在多个线程中运行,而是循环操作。有关真正的并行处理,请参阅parallel_interleave
答案 1 :(得分:-1)
当前的Tensorflow版本(02/2018中的v1.5)似乎不支持数据集API中的文件名洗牌。这是一个简单的使用numpy的工作:
import numpy as np
import tensorflow as tf
myShuffledFileList = np.random.choice(myInputFileList, size=len(myInputFileList), replace=False).tolist()
dataset = tf.data.TFRecordDataset(myShuffledFileList)