I am trying to train a GAN on a Google Cloud TPU. I adapted the code from https://github.com/deepak112/Keras-SRGAN/blob/master/simplified/train.py, mainly porting it to the Dataset API and my own dataset. When I run the code, training is extremely slow, and both CPU and memory utilization on the TPU are very low, as if I have a number of bad practices and bottlenecks in there. Does anyone see what I am doing wrong?
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import keras
from keras.applications.vgg19 import VGG19
import keras.backend as K
from keras.layers import add, Dense, Input, Lambda
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model
import numpy as np
from tensorflow.keras.optimizers import SGD, Adam
import tensorflow as tf
import tensorflow_datasets as tfds
image_lr_shape = (64,64,3)
image_hr_shape = (256,256,3)
image_lr_dir = "gs://******/Dataset/Small/*.png"
image_hr_dir = "gs://******/Dataset/Large/*.png"
model_dir = "gs://******/Models"
epochs = 1000
batch_size = 128
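# Perceptual ("content") loss: MSE between VGG19 block5_conv4 feature maps
# of the real HR image and the generated image.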
def vgg_loss(y_true, y_pred):
    vgg19 = VGG19(include_top=False, weights="imagenet", input_shape=image_hr_shape)
    vgg19.trainable = False
    for l in vgg19.layers:
        l.trainable = False
    loss_model = Model(inputs=vgg19.input, outputs=vgg19.get_layer("block5_conv4").output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
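# Generator residual block: Conv-BN-PReLU-Conv-BN with an identity skip connection.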
def residual_block(model, kernel_size, filters, strides):
    gen = model
    model = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = PReLU(alpha_initializer="zeros", alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])(model)
    model = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen, model])
    return model
def up_sampling_block(model, kernel_size, filters, strides):
    model = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding="same")(model)
    model = UpSampling2D(size=2)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model
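# Discriminator block: strided Conv-BN-LeakyReLU.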
def discriminator_block(model, filters, kernel_size, strides):
    model = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model
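# SRGAN generator: 9x9 conv + PReLU, 16 residual blocks, a long skip connection,
# two 2x upsampling blocks (64x64 -> 256x256), tanh output.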
def generator_network():
    gen_input = Input(shape=image_lr_shape)
    model = Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(gen_input)
    model = PReLU(alpha_initializer="zeros", alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])(model)
    gen_model = model
    for index in range(16):
        model = residual_block(model, 3, 64, 1)
    model = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen_model, model])
    for index in range(2):
        model = up_sampling_block(model, 3, 256, 1)
    model = Conv2D(filters=3, kernel_size=9, strides=1, padding="same")(model)
    model = Activation("tanh")(model)
    generator_model = Model(inputs=gen_input, outputs=model)
    return generator_model
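# SRGAN discriminator: a stack of strided conv blocks, then Dense(1024)
# and a sigmoid real/fake score.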
def discriminator_network():
    dis_input = Input(shape=image_hr_shape)
    model = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(dis_input)
    model = LeakyReLU(alpha=0.2)(model)
    model = discriminator_block(model, 64, 3, 2)
    model = discriminator_block(model, 128, 3, 1)
    model = discriminator_block(model, 128, 3, 2)
    model = discriminator_block(model, 256, 3, 1)
    model = discriminator_block(model, 256, 3, 2)
    model = discriminator_block(model, 512, 3, 1)
    model = discriminator_block(model, 512, 3, 2)
    model = Flatten()(model)
    model = Dense(1024)(model)
    model = LeakyReLU(alpha=0.2)(model)
    model = Dense(1)(model)
    model = Activation("sigmoid")(model)
    discriminator_model = Model(inputs=dis_input, outputs=model)
    return discriminator_model
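# Combined model: LR input -> generator -> (SR image, discriminator score),
# trained with VGG loss plus down-weighted adversarial loss while the
# discriminator is frozen.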
def get_gan_network(discriminator, generator, optimizer):
    discriminator.trainable = False
    gan_input = Input(shape=image_lr_shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=[x, gan_output])
    gan.compile(loss=[vgg_loss, "binary_crossentropy"], loss_weights=[1., 1e-3], optimizer=optimizer)
    return gan
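# Decodes one PNG from GCS into a float32 image in [0, 1].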
def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img
#============================================WARNING: Dataset size hardcoded!
print("Connecting to TPU...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="******",zone="******",project="******")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    print("Loading dataset...")
    # Both pipelines shuffle with the same seed so LR/HR batches stay paired.
    image_hr = tf.data.Dataset.list_files(image_hr_dir, shuffle=False).map(load_image).shuffle(6777, seed=12345679, reshuffle_each_iteration=True).batch(batch_size).prefetch(4)
    image_lr = tf.data.Dataset.list_files(image_lr_dir, shuffle=False).map(load_image).shuffle(6777, seed=12345679, reshuffle_each_iteration=True).batch(batch_size).prefetch(4)
    batch_count = int(6777 / batch_size)
    print("Compiling neural network---")
    generator = generator_network()
    discriminator = discriminator_network()
    adam = Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    generator.compile(loss=vgg_loss, optimizer=adam)
    discriminator.compile(loss="binary_crossentropy", optimizer=adam)
    gan = get_gan_network(discriminator, generator, adam)
    print("Starting training...")
    for e in range(1, epochs + 1):
        print("-" * 15, "Epoch %d" % e, "-" * 15)
        for b in range(batch_count):
            print("-" * 10, "Batch %d" % b, "-" * 10)
            # Pull one batch of HR/LR images and generate SR images from the LR batch.
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            batch_sr = generator.predict(batch_lr)
            # Smoothed labels: real in [0.8, 1), fake in [0, 0.2).
            real_y = tf.random.uniform(shape=(batch_size, 1), minval=0.8, maxval=1)
            fake_y = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=0.2)
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(batch_hr, real_y)
            d_loss_fake = discriminator.train_on_batch(batch_sr, fake_y)
            # Fresh batch for the combined generator update, discriminator frozen.
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            gan_y = tf.random.uniform(shape=(batch_size, 1), minval=0.8, maxval=1)
            discriminator.trainable = False
            loss_gan = gan.train_on_batch(batch_lr, [batch_hr, gan_y])
            print("Loss d_real, Loss d_fake, Loss network")
            print(d_loss_real, d_loss_fake, loss_gan)
        # Save all three models once per epoch.
        os.makedirs(os.path.join(model_dir, str(e)))
        generator.save(os.path.join(model_dir, str(e), "gen.h5"))
        discriminator.save(os.path.join(model_dir, str(e), "dis.h5"))
        gan.save(os.path.join(model_dir, str(e), "gan.h5"))
Update: I added more debug prints and it fails at batch_hr: there is no error, it just hangs there.
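For what it's worth, my current suspicion is the input pipeline: as far as I understand, every call to image_hr.take(1) builds a brand-new pipeline, so shuffle(6777) has to refill its entire buffer (re-reading and decoding thousands of PNGs from GCS) before it can yield a single batch, which would explain the hang at batch_hr. A minimal sketch of what I assume the intended pattern is, reusing the image_hr, image_lr, and batch_count defined above, with one persistent iterator per dataset:

# Sketch: persistent iterators fill the shuffle buffer once, when the
# iterator is created, instead of once per image_hr.take(1) call.
# For multiple epochs the iterators would be recreated at each epoch start.
hr_iter = iter(tfds.as_numpy(image_hr))
lr_iter = iter(tfds.as_numpy(image_lr))
for b in range(batch_count):
    batch_hr = next(hr_iter)  # ndarray, shape (batch_size, 256, 256, 3)
    batch_lr = next(lr_iter)  # ndarray, shape (batch_size, 64, 64, 3)
    # ...rest of the training step as above...

Is this the right way to consume the datasets, or is the bottleneck somewhere else?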