如何加载我自己的数据集以在Tensorflow中创建DCGAN?

时间:2020-06-02 00:25:03

标签: python tensorflow mnist generative-adversarial-network

我尝试按照教程here创建DCGAN,仅更改输入数据,但无法将数据加载到适合模型训练的适当形状中。这是我的代码:

import pathlib
data_dir = pathlib.Path(data_dir)

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=1)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
def process_path(file_path):
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img
labeled_ds = list_ds.map(process_path)

然后我从那里复制了本教程,修改了参数以满足我的数据集,然后得到了最终结果:

train(labeled_ds, EPOCHS)

但这样做时出现以下错误。

ValueError: Input 0 of layer sequential_1 is incompatible with the layer: expected ndim=4, found ndim=3. Full shape received: [224, 320, 1]

如何向tensorflow数据集添加额外的维度?或者,相对于本教程中如何加载MNIST数据集,我在这里做错什么了吗?谢谢

1 个答案:

答案 0 :(得分:0)

您似乎缺少了批次大小(第0维)。

最简单,最灵活的添加方法是在.batch()上使用tf.data.Dataset函数:

batch_size = 1  # Or something else.
labeled_ds = list_ds.map(process_path).batch(batch_size)

这应该修改张量的形状,从
由模型期望的[height, width, channels][batch_size, height, width, channels]
希望以上内容对您有用!