遵循 Keras 的本教程
https://keras.io/examples/generative/dcgan_overriding_train_step/#override-trainstep
我想生成 256 x 256 分辨率的图像,但我需要降低 google colab 的批量大小
我有代码:
from keras.preprocessing import image_dataset_from_directory, image
from matplotlib import pyplot as plt
import matplotlib
from keras.models import Sequential
from keras.layers import Conv2D, LeakyReLU, Flatten, Dropout, Dense, Reshape, Conv2DTranspose
import keras
import tensorflow as tf
# Edge length (pixels) of the square training / generated images.
DIMENSION = 256
# Discriminator input shape: 256 x 256 RGB.
SHAPE = (DIMENSION, DIMENSION, 3)
# Length of the generator's input noise vector.
LATENT_DIM = 256
# Folder holding the training images (Google Drive mounted in Colab).
DIR = 'drive/MyDrive/data/test'
def get_dataset(batch_size=32):
    """Load the unlabeled image dataset used to train the GAN.

    Images are read from ``DIR``, resized to ``DIMENSION`` x ``DIMENSION``
    and rescaled from uint8 [0, 255] to float [0, 1], matching the
    generator's sigmoid output range.

    Args:
        batch_size: Number of images per batch. Exposed as a parameter
            (default 32, the previous hard-coded value) so it can be
            lowered on memory-constrained runtimes such as Colab.

    Returns:
        A ``tf.data.Dataset`` yielding float image batches in [0, 1].
    """
    dataset = image_dataset_from_directory(
        DIR,
        label_mode=None,
        image_size=(DIMENSION, DIMENSION),
        batch_size=batch_size,
    )
    return dataset.map(lambda x: x / 255.0)
def create_discriminator():
    """Build the discriminator: three strided conv stages, then a
    sigmoid real/fake score.

    Input is a 256 x 256 x 3 image (``SHAPE``); each strided conv halves
    the spatial resolution (256 -> 128 -> 64 -> 32) before flattening.
    """
    stages = [
        Conv2D(DIMENSION, kernel_size=4, strides=2, padding='same',
               input_shape=SHAPE),
        LeakyReLU(0.2),
        Conv2D(DIMENSION * 2, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(0.2),
        Conv2D(DIMENSION * 2, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(0.2),
        Flatten(),
        Dropout(0.2),
        # Single sigmoid unit: probability that the input is "fake"
        # under this script's labeling convention.
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(stages)
def create_generator():
    """Build the generator: latent vector -> 256 x 256 x 3 image.

    A dense projection is reshaped to a 16 x 16 x ``LATENT_DIM`` grid,
    then four transposed-conv stages double the resolution each time
    (16 -> 32 -> 64 -> 128 -> 256). A final sigmoid conv emits RGB
    pixels in [0, 1].
    """
    model = Sequential()
    model.add(Dense(DIMENSION * LATENT_DIM, input_shape=(LATENT_DIM,)))
    model.add(Reshape((16, 16, LATENT_DIM)))
    # Four upsampling stages with channel widths 1x, 2x, 4x, 8x LATENT_DIM.
    for width in (1, 2, 4, 8):
        model.add(Conv2DTranspose(LATENT_DIM * width, kernel_size=4,
                                  strides=2, padding='same'))
        model.add(LeakyReLU(0.2))
    model.add(Conv2D(3, kernel_size=5, padding='same', activation='sigmoid'))
    return model
class GAN(keras.Model):
    """DCGAN wrapper that trains a discriminator and a generator in lockstep.

    Follows the Keras DCGAN tutorial pattern: ``train_step`` is overridden
    so that ``fit()`` performs one discriminator update and one generator
    update per batch of real images.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers, the shared loss, and the loss metrics."""
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name='d_loss')
        self.g_loss_metric = keras.metrics.Mean(name='g_loss')

    @property
    def metrics(self):
        # Declaring the metrics here lets Keras reset them at epoch start.
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        # BUG FIX: derive the batch size from the incoming batch instead of
        # hard-coding 4. The dataset yields batches of 32, so 4 fake + 32
        # real = 36 discriminator inputs versus 4 + 4 = 8 labels, causing
        # "Incompatible shapes: [8,1] vs. [36,1]". tf.shape(real_images)[0]
        # also handles the smaller final batch of an epoch.
        batch_size = tf.shape(real_images)[0]

        # --- Discriminator step: fakes labeled 1, real images labeled 0. ---
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(random_latent_vectors)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # (Removed a leftover debug print of the combined image tensor.)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Label smoothing: small random noise on the targets stabilizes training.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # --- Generator step: push the discriminator toward 0 ("real") on fakes. ---
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            'd_loss': self.d_loss_metric.result(),
            'g_loss': self.g_loss_metric.result(),
        }
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves sample images and the generator after each epoch."""

    def __init__(self, num_img=3, latent_dim=LATENT_DIM):
        # BUG FIX: the base Callback constructor was never called, leaving
        # the callback's internal state uninitialized.
        super(GANMonitor, self).__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Sample ``num_img`` images from the generator and persist them."""
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # BUG FIX: ``.numpy()`` returns a new array; the original call
        # discarded its result, leaving ``generated_images`` a tensor.
        generated_images = generated_images.numpy()
        for i in range(self.num_img):
            img = image.array_to_img(generated_images[i])
            img.save('drive/MyDrive/data/checkpoints/generated_img_%09d_%d.png' % (epoch, i))
        # Checkpoint the generator so samples can be drawn later.
        self.model.generator.save('drive/MyDrive/data/gen_model.h5')
# --- Training entry point --------------------------------------------------
dataset = get_dataset()
# NOTE(review): 10000 epochs over the full dataset; presumably meant to be
# interrupted manually on Colab — confirm.
epochs = 10000
gan = GAN(discriminator=create_discriminator(), generator=create_generator(), latent_dim=LATENT_DIM)
# Both networks use Adam at the same learning rate; binary cross-entropy is
# the standard loss for a sigmoid discriminator.
gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001), g_optimizer=keras.optimizers.Adam(learning_rate=0.0001), loss_fn=keras.losses.BinaryCrossentropy())
# GANMonitor saves 8 sample images and the generator model after every epoch.
gan.fit(dataset, epochs=epochs, callbacks=[GANMonitor(num_img=8, latent_dim=LATENT_DIM)])
不幸的是我得到了
<ipython-input-34-7f01c745080b> in <module>()
118 gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001), g_optimizer=keras.optimizers.Adam(learning_rate=0.0001), loss_fn=keras.losses.BinaryCrossentropy())
119
--> 120 gan.fit(dataset, epochs=epochs, callbacks=[GANMonitor(num_img=8, latent_dim=LATENT_DIM)])
6 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
886 # Lifting succeeded, so variables are initialized and we can run the
887 # stateless function.
--> 888 return self._stateless_fn(*args, **kwds)
889 else:
890 _, _, _, filtered_flat_args = \
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
2942 return graph_function._call_flat(
-> 2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
2945 @property
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1917 # No tape is watching; skip to running the function.
1918 return self._build_call_outputs(self._inference_function.call(
-> 1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
1921 args,
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
558 inputs=args,
559 attrs=attrs,
--> 560 ctx=ctx)
561 else:
562 outputs = execute.execute_with_cancellation(
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Incompatible shapes: [8,1] vs. [36,1]
[[node binary_crossentropy/logistic_loss/mul (defined at <ipython-input-34-7f01c745080b>:79) ]] [Op:__inference_train_function_66509]
鉴别器摘要:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 128, 128, 256) 12544
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 128, 128, 256) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 64, 512) 2097664
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 64, 512) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 32, 32, 512) 4194816
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 32, 32, 512) 0
_________________________________________________________________
flatten (Flatten) (None, 524288) 0
_________________________________________________________________
dropout (Dropout) (None, 524288) 0
_________________________________________________________________
dense (Dense) (None, 1) 524289
=================================================================
Total params: 6,829,313
Trainable params: 6,829,313
Non-trainable params: 0
生成器摘要:
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 65536) 16842752
_________________________________________________________________
reshape (Reshape) (None, 16, 16, 256) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 32, 32, 256) 1048832
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 64, 64, 512) 2097664
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 64, 64, 512) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 128, 128, 1024) 8389632
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 1024) 0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 256, 256, 2048) 33556480
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 256, 2048) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 256, 256, 3) 153603
=================================================================
Total params: 62,088,963
Trainable params: 62,088,963
Non-trainable params: 0
我不知道我做错了什么,tf.shape(real_images)[0] 到底做了什么,或者是否有任何方法可以控制 real_images 的大小。
答案 0 :(得分:0)
我找到了解决方案,就是在这一行中从 dir 读取图像:
dataset = image_dataset_from_directory(DIR, label_mode=None, image_size=(DIMENSION, DIMENSION), batch_size=32)
在此处把 batch_size 改小(使其与 train_step 中硬编码的 batch_size = 4 一致)即可解决问题;更稳妥的做法是在 train_step 中用 tf.shape(real_images)[0] 动态获取批量大小,这样任何批量大小都能正常工作。