有没有办法在Tensorflow中的另一个数据集中使用tf.data.Dataset?

时间:2018-02-27 23:59:23

标签: python tensorflow tensorflow-datasets

我正在进行细分。每个训练样本具有多个具有分割掩模的图像。我正在尝试编写input_fn以将每个训练样本中的所有蒙版图像合并为一个。 我打算使用两个Datasets,一个迭代样本文件夹,另一个读取所有掩码作为一个大批量,然后将它们合并到一个张量。

调用嵌套make_one_shot_iterator时出现错误。我知道这种方法有点拉伸,而且很可能是数据集,而不是为这种用法而设计的。但那我该如何解决这个问题,以免我使用tf.py_func?

以下是数据集的简化版本:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.
        list_files(sample_path+"/masks/*.png")
        .map(tf.read_file)
        .map(lambda x: tf.image.decode_image(x, channels=1))
        .batch(1024)) # maximum number of objects
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

1 个答案:

答案 0 :(得分:1)

如果嵌套数据集只有一个元素,则可以在嵌套数据集上使用tf.contrib.data.get_single_element(),而不是创建迭代器:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024)) # maximum number of objects
    masks = tf.contrib.data.get_single_element(masks_ds)
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()

此外,您可以使用tf.data.Dataset.flat_map()tf.data.Dataset.interleave()tf.contrib.data.parallel_interleave() transformationw在函数内执行嵌套Dataset计算,并将结果展平为单个{ {1}}。例如,要将所有样本放在一个Dataset中:

Dataset