训练 GAN 时出现 InvalidArgumentError

时间:2021-04-04 00:39:05

标签: python-3.x tensorflow keras deep-learning generative-adversarial-network

遵循 Keras 的本教程

https://keras.io/examples/generative/dcgan_overriding_train_step/#override-trainstep

我想生成 256 x 256 分辨率的图像,但我需要降低 google colab 的批量大小

我有代码:

from keras.preprocessing import image_dataset_from_directory, image
from matplotlib import pyplot as plt
import matplotlib
from keras.models import Sequential
from keras.layers import Conv2D, LeakyReLU, Flatten, Dropout, Dense, Reshape, Conv2DTranspose
import keras
import tensorflow as tf

DIMENSION = 256
SHAPE = (DIMENSION, DIMENSION, 3)

LATENT_DIM = 256

DIR = 'drive/MyDrive/data/test'

def get_dataset():
    dataset = image_dataset_from_directory(DIR, label_mode=None, image_size=(DIMENSION, DIMENSION), batch_size=32)
    dataset = dataset.map(lambda x: x / 255.0)
    return dataset

def create_discriminator():
    model = Sequential()
    model.add(Conv2D(DIMENSION, kernel_size=4, strides=2, padding='same', input_shape=SHAPE))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(DIMENSION * 2, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(DIMENSION * 2, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dropout(0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model

def create_generator():
    model = Sequential()
    model.add(Dense(DIMENSION * LATENT_DIM, input_shape=(LATENT_DIM,)))
    model.add(Reshape((16, 16, LATENT_DIM)))
    model.add(Conv2DTranspose(LATENT_DIM, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2DTranspose(LATENT_DIM * 2, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2DTranspose(LATENT_DIM * 4, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2DTranspose(LATENT_DIM * 8, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(3, kernel_size=5, padding='same', activation='sigmoid'))
    return model

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name='d_loss')
        self.g_loss_metric = keras.metrics.Mean(name='g_loss')
    
    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        batch_size = 4
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(random_latent_vectors)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        print(combined_images)
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
            grads = tape.gradient(g_loss, self.generator.trainable_weights)
            self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            'd_loss': self.d_loss_metric.result(),
            'g_loss': self.g_loss_metric.result()
        }

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=3, latent_dim=LATENT_DIM):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images.numpy()
        for i in range(self.num_img):
            img = image.array_to_img(generated_images[i])
            img.save('drive/MyDrive/data/checkpoints/generated_img_%09d_%d.png' % (epoch, i))
        self.model.generator.save('drive/MyDrive/data/gen_model.h5')

dataset = get_dataset()

epochs = 10000

gan = GAN(discriminator=create_discriminator(), generator=create_generator(), latent_dim=LATENT_DIM)
gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001), g_optimizer=keras.optimizers.Adam(learning_rate=0.0001), loss_fn=keras.losses.BinaryCrossentropy())

gan.fit(dataset, epochs=epochs, callbacks=[GANMonitor(num_img=8, latent_dim=LATENT_DIM)])

不幸的是我得到了

<ipython-input-34-7f01c745080b> in <module>()
    118 gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001), g_optimizer=keras.optimizers.Adam(learning_rate=0.0001), loss_fn=keras.losses.BinaryCrossentropy())
    119 
--> 120 gan.fit(dataset, epochs=epochs, callbacks=[GANMonitor(num_img=8, latent_dim=LATENT_DIM)])

6 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    886         # Lifting succeeded, so variables are initialized and we can run the
    887         # stateless function.
--> 888         return self._stateless_fn(*args, **kwds)
    889     else:
    890       _, _, _, filtered_flat_args = \

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2942     return graph_function._call_flat(
-> 2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 
   2945   @property

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Incompatible shapes: [8,1] vs. [36,1]
     [[node binary_crossentropy/logistic_loss/mul (defined at <ipython-input-34-7f01c745080b>:79) ]] [Op:__inference_train_function_66509]

鉴别器摘要:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 128, 128, 256)     12544     
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 128, 128, 256)     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 512)       2097664   
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 512)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 512)       4194816   
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 512)       0         
_________________________________________________________________
flatten (Flatten)            (None, 524288)            0         
_________________________________________________________________
dropout (Dropout)            (None, 524288)            0         
_________________________________________________________________
dense (Dense)                (None, 1)                 524289    
=================================================================
Total params: 6,829,313
Trainable params: 6,829,313
Non-trainable params: 0

发电机对称:

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 65536)             16842752  
_________________________________________________________________
reshape (Reshape)            (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 32, 32, 256)       1048832   
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 64, 64, 512)       2097664   
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 64, 64, 512)       0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 128, 128, 1024)    8389632   
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 128, 128, 1024)    0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 256, 256, 2048)    33556480  
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 256, 256, 2048)    0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 256, 256, 3)       153603    
=================================================================
Total params: 62,088,963
Trainable params: 62,088,963
Non-trainable params: 0

我不知道我做错了什么,tf.shape(real_images)[0] 到底做了什么,或者是否有任何方法可以控制 real_images 的大小。

1 个答案:

答案 0 :(得分:0)

我找到了解决方案,就是在这一行中从 dir 读取图像:

dataset = image_dataset_from_directory(DIR, label_mode=None, image_size=(DIMENSION, DIMENSION), batch_size=32)

在此处更改 bach_size 可以解决问题。