How do I use tf.data.Dataset.interleave() with a custom function in TF 2?

Time: 2020-05-30 17:02:29

Tags: tensorflow tensorflow2.0 tensorflow-datasets

I am using TF 2.2 and I am trying to build a pipeline with tf.data.

The following works fine:

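import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

def load_image(filePath, label):
    # filePath is a string tensor inside the pipeline, so tf.print logs its runtime value
    tf.print('Loading File:', filePath)
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations=False)

    return image, label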

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.map(load_image, num_parallel_calls=AUTOTUNE)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

I would like to use load_image() together with Dataset.interleave(), so I tried:

# TrainDS Pipeline
trainDS = getDataset()
trainDS = trainDS.shuffle(size['train'])
trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

for d in trainDS:
    print('Image: {} - Label: {}'.format(d[0], d[1]))

But I get the following error:

Exception has occurred: TypeError
`map_func` must return a `Dataset` object. Got <class 'tuple'>
  File "/data/dev/train_daninhas.py", line 44, in <module>
    trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

How can I modify my code so that Dataset.interleave() works with load_image() to read the images in parallel?

1 Answer:

Answer 0 (score: 2)

As the error suggests, you need to modify load_image so that it returns a Dataset object. I worked through an example with two images to show how to do this in tensorflow 2.2.0.
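A minimal sketch of that approach, assuming the dataset yields (file path, label) pairs as in the question. The two temporary PNG files and the from_tensor_slices() call are hypothetical stand-ins for getDataset(), which the question does not show. Wrapping the result in tf.data.Dataset.from_tensors() is one way to satisfy interleave(), not necessarily the only one.

```python
import os
import tempfile

import tensorflow as tf


def load_image(filePath, label):
    """Read one image and wrap (image, label) in a single-element Dataset."""
    raw_bytes = tf.io.read_file(filePath)
    image = tf.io.decode_image(raw_bytes, expand_animations=False)
    # interleave() requires map_func to return a Dataset, not a tuple.
    return tf.data.Dataset.from_tensors((image, label))


# Hypothetical stand-in for getDataset(): write two tiny PNG files to a temp dir.
tmp_dir = tempfile.mkdtemp()
labels = [0, 1]
file_paths = []
for i in labels:
    # 4x4 RGB image whose pixel value depends on the label (0 or 100).
    pixels = tf.cast(tf.fill([4, 4, 3], i * 100), tf.uint8)
    path = os.path.join(tmp_dir, 'image_{}.png'.format(i))
    tf.io.write_file(path, tf.io.encode_png(pixels))
    file_paths.append(path)

trainDS = tf.data.Dataset.from_tensor_slices((file_paths, labels))
trainDS = trainDS.shuffle(len(file_paths))
trainDS = trainDS.interleave(
    load_image,
    cycle_length=4,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

# Materialize the results so they can be inspected after the pipeline runs.
results = [(image.numpy(), int(label.numpy())) for image, label in trainDS]
for image, label in results:
    print('Image shape: {} - Label: {}'.format(image.shape, label))
```

Because each inner dataset here holds a single element, Dataset.map() with num_parallel_calls (as in the working pipeline from the question) should give comparable parallelism. interleave() tends to pay off more when each input expands into many elements, for example one file containing many records.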

Hope this helps!