与线程/队列相比,tf.data.Dataset输入管道提供了不良结果

时间:2019-02-18 11:59:06

标签: tensorflow tensorflow-datasets

之前,我使用线程和队列作为数据管道,并且两个GPU的利用率都很高(数据是动态创建的)。我想使用tf数据集,但是我很难复制结果。

我尝试了很多方法。由于我是动态创建数据的,因此from_generator()方法似乎很完美。您在下面看到的代码是我最后一次尝试。尽管我正在使用map()函数来处理生成的图像,但是似乎在创建数据方面存在瓶颈。我在下面的代码中尝试过的我想以某种方式“多线程化”生成器,因此同时会有更多数据输入。但是到目前为止,没有更好的结果。

def generator(n):
    with tf.device('/cpu:0'):
        while True:
            ...
            yield image, label

def get_generator(n):
    return partial(generator, n)

def dataset(n):
    return tf.data.Dataset.from_generator(get_generator(n), output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([None,None,1]),tf.TensorShape([None,None,1])))

def input_fn():
# ds = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([None,None,1]),tf.TensorShape([None,None,1])))
    ds = tf.data.Dataset.range(BATCH_SIZE).apply(tf.data.experimental.parallel_interleave(dataset, cycle_length=BATCH_SIZE))
    ds = ds.map(map_func=lambda img, lbl: processImage(img, lbl))
    ds = ds.shuffle(SHUFFLE_SIZE)
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(1)
return ds

预期结果将是较高的GPU利用率(> 80%),但目前实际上仅为10/20%。

1 个答案:

答案 0 :(得分:0)

您可以改用tf.data.Dataset.from_tensor_slices。 只需传递图片/标签路径即可。该函数接受文件名作为参数。

def input_func():
    dataset = tf.data.Dataset.from_tensor_slices(images_path, labels_path)
    dataset = dataset.shuffle().repeat()
    ...
    return dataset