将张量流tf.data.Dataset FlatMapDataset转换为TensorSliceDataset

时间:2018-04-18 03:44:00

标签: python tensorflow vectorization tensorflow-datasets

我想将一个tf.Strings列表传递给.map(_parse_function)函数。

 def _parse_function(self, img_path):
        img_str = tf.read_file(img_path)
        img_decode = tf.image.decode_jpeg(img_str, channels=3)
        img_decode = tf.divide(tf.cast(img_decode , tf.float32),255)
        return img_decode

tf.data.Dataset的类型为TensorSliceDataset时,

dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))

我可以干脆做 dataset_from_slices.map(_parse_function),有效。

但是,dataset_from_generator = tf.data.Dataset.from_generator(...)会返回Dataset类型的FlatMatDatasetdataset_from_generator.map(_parse_function)会出现以下错误:

InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]

如果我将第一行更改为:

img_str = tf.read_file(img_path[0])

这也有效但我只得到第一张图片,这不是我要找的。有什么建议吗?

1 个答案:

答案 0 :(得分:1)

听起来dataset_from_generator的元素已被批量处理。最简单的补救措施是使用tf.contrib.data.unbatch()将它们转换回单个元素:

# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)

# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())

dataset = dataset.map(_parse_function)