来自Keras模型中图像列表的TensorFlow数据集

时间:2019-02-28 19:55:42

标签: python tensorflow keras

我试图了解如何读取本地图像,将其用作TensorFlow Dataset并使用TF Dataset训练Keras模型。我正在追踪TF Keras MNIST TPU tutorial。我想要阅读我的图像集并对其进行训练的唯一区别。

假设我有图像列表(文件名)和相应的标签列表。

files = [...] # list of file names
labels = [...] # list of labels (integers)
images = tf.constant(files) # or tf.convert_to_tensor(files)
labels = tf.constant(labels) # or tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(len(files))
dataset = dataset.repeat()
dataset = dataset.map(parse_function).batch(batch_size)

parse_function是一个简单的函数,它读取输入的文件名并产生图像数据和相应的标签,例如

def parse_function(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_image(image_string)
    image = tf.cast(image_decoded, tf.float32)
    return image, label

这时,我有一个dataset,它是一个tf.data.Dataset类型(更准确地说是tf.data.BatchDataset),我将其从{{3}传递给了keras模型trained_model },例如

history = trained_model.fit(dataset, ...)

但是,此时代码因以下错误而中断:

AttributeError: 'BatchDataset' object has no attribute 'ndim'

错误来自keras,它会像这样对给定的输入执行检查

from keras import backend as K
K.is_tensor(dataset) # which returns false

Keras尝试确定输入的类型,由于它不是张量,因此假定它是numpy数组并尝试获取其维数。这就是错误发生的原因。

我的问题如下:

  • 我可以正确读取TF数据集吗?我在互联网上查阅了大量示例,似乎我正在按照人们的建议阅读
  • 为什么我的数据集不是张量?可能是我需要执行其他转换,但TF tutorial
  • 并非如此
  • 为什么在TF tutorial中一切都可以与tf数据集一起使用,我真的看不出它们在读取MNIST数据(这是不同的数据格式,但最终它们获得图像)的方式上有什么不同,以及我在这里。

任何建议将不胜感激。

请注意,即使TF tutorial也是关于TPU的,其结构也使其可以同时在TPU和CPU / GPU上使用。

1 个答案:

答案 0 :(得分:0)

原来问题出在使用Keras模型。 TF教程中的示例依赖于使用tf.keras模块的Keras模型构建(所有层,模型等均来自tf.keras)。当我使用的模型(DenseNet)依赖于纯keras模块时,即所有层都来自keras模块而不是tf.keras。这将导致在keras模型的fit方法中检查tf.data.Dataset的ndim。一旦我将DenseNet调整为使用tf.keras图层,一切将再次正常工作。