Tensorflow - 使用parallel_interleave从多个tfrecords中读取不同的block_lengths?

时间:2018-06-15 18:40:20

标签: python tensor tfrecord

我正在尝试阅读三种不同长度的大型tfrecords,并将它们全部并行读取:

files = [ filename1, filename2, filename3 ]

data = tf.data.TFRecordDataset(files)

data = data.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(data),
        cycle_length=3,block_length = [10,5,3]))

data = data.shuffle(
    buffer_size = 100)

data = data.apply(
    tf.contrib.data.map_and_batch(
        map_func=parse, 
        batch_size=100))

data = data.prefetch(10)

,但TensorFlow不允许每个文件源使用不同的块长度:

InvalidArgumentError: block_length must be a scalar

我可以使用不同的小批量大小创建三个不同的数据集,但这需要3倍的资源,而这不是我的机器限制所给出的选项。

有哪些可能的解决方案?

1 个答案:

答案 0 :(得分:0)

这是答案,我想出了如何在我的约束中做到这一点。

为每个文件创建数据集,为每个文件定义每个迷你批量大小,并将get_next()输出连接在一起。这适合我的机器并高效运行。