我正在尝试使用tf.data.Dataset
和TFRecordDataset
并行化输入管道。
files = tf.data.Dataset.list_files("./data/*.avro")
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=16)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
preprocess_fn, 512, num_parallel_batches=16) )
如果输入的是AVRO文件(如JSON),我不确定如何写preprocess_fn
。
当前,我正在使用tf.data.Dataset.from_generator
并将由pyavroc
或类似的Avro阅读器解析的Avro记录提供给它。但是我不确定如何将其并行化,因为from_generator
方法没有num_parallel_reads
选项可用。
def gen():
for file in all_avro_files:
x, y = read_local_avro_data(file)
for i, sample in enumerate( x ):
yield sample, y[i]
dataset = tf.data.Dataset.from_generator( gen,
(tf.float32, tf.float64),
( tf.TensorShape([13000]), tf.TensorShape([])
)
)
逐个文件读取显然是一个瓶颈,在耗尽前一批数据后,我看到所有内核都在等待数据。
如何优化这两种方法?