avro文件中的tf.data.dataset

时间:2018-09-03 04:31:05

标签: python tensorflow keras tensorflow-datasets tfrecord

我正在尝试使用tf.data.DatasetTFRecordDataset并行化输入管道。

files = tf.data.Dataset.list_files("./data/*.avro")
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=16)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    preprocess_fn, 512, num_parallel_batches=16) )

如果输入的是AVRO文件(如JSON),我不确定如何写preprocess_fn


当前,我正在使用tf.data.Dataset.from_generator并将由pyavroc或类似的Avro阅读器解析的Avro记录提供给它。但是我不确定如何将其并行化,因为from_generator方法没有num_parallel_reads选项可用。

def gen():
    for file in all_avro_files:
        x, y = read_local_avro_data(file)
        for i, sample in enumerate( x ):
            yield sample, y[i]

dataset = tf.data.Dataset.from_generator( gen, 
            (tf.float32, tf.float64),
            ( tf.TensorShape([13000]), tf.TensorShape([]) 
        ) 
    )

逐个文件读取显然是一个瓶颈,在耗尽前一批数据后,我看到所有内核都在等待数据。

如何优化这两种方法?

0 个答案:

没有答案