我想将一个tf.Strings列表传递给.map(_parse_function)
函数。
def _parse_function(self, img_path):
img_str = tf.read_file(img_path)
img_decode = tf.image.decode_jpeg(img_str, channels=3)
img_decode = tf.divide(tf.cast(img_decode , tf.float32),255)
return img_decode
当tf.data.Dataset
的类型为TensorSliceDataset
时,
dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))
我可以干脆做
dataset_from_slices.map(_parse_function
),有效。
但是,dataset_from_generator = tf.data.Dataset.from_generator(...)
会返回Dataset
类型的FlatMatDataset
,dataset_from_generator.map(_parse_function)
会出现以下错误:
InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]
如果我将第一行更改为:
img_str = tf.read_file(img_path[0])
这也有效但我只得到第一张图片,这不是我要找的。有什么建议吗?
答案 0 :(得分:1)
听起来dataset_from_generator
的元素已被批量处理。最简单的补救措施是使用tf.contrib.data.unbatch()
将它们转换回单个元素:
# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)
# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())
dataset = dataset.map(_parse_function)