Slow GAN training on Google Cloud TPU

Time: 2020-03-29 10:21:57

Tags: python-3.x tensorflow2.0 tensorflow-datasets generative-adversarial-network tpu

I am trying to train a GAN on a Google Cloud TPU node. I adapted the code from https://github.com/deepak112/Keras-SRGAN/blob/master/simplified/train.py, mainly porting it to the Dataset API and my own dataset. When I run it, training is extremely slow, and CPU and memory utilization on the TPU are both low, as if I have introduced a lot of bad practices and bottlenecks. Does anyone see what I am doing wrong?

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications.vgg19 import VGG19
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Flatten, Input, LeakyReLU, PReLU,
                                     UpSampling2D, add)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


image_lr_shape = (64,64,3)
image_hr_shape = (256,256,3)
image_lr_dir = "gs://******/Dataset/Small/*.png"
image_hr_dir = "gs://******/Dataset/Large/*.png"
model_dir = "gs://******/Models"
epochs = 1000
batch_size = 128

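# Perceptual (content) loss: MSE between VGG19 block5_conv4 feature maps of
# the target and generated images. Note this constructs a fresh VGG19 (and
# loads the ImageNet weights) every time the function is invoked.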
def vgg_loss(y_true, y_pred):
    vgg19 = VGG19(include_top=False, weights="imagenet", input_shape=image_hr_shape)
    vgg19.trainable = False
    for l in vgg19.layers:
        l.trainable = False
    loss_model = Model(inputs=vgg19.input,outputs=vgg19.get_layer("block5_conv4").output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

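# SRGAN residual block: Conv-BN-PReLU-Conv-BN with an elementwise skip
# connection from the block input.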
def residual_block(model,kernel_size,filters,strides):
    gen = model
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen, model])
    return model

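# Upsampling block: 3x3 convolution, 2x upsampling, LeakyReLU.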
def up_sampling_block(model,kernel_size,filters,strides):
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = UpSampling2D(size=2)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

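# Conv-BN-LeakyReLU building block for the discriminator.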
def discriminator_block(model, filters, kernel_size, strides):
    model = Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = LeakyReLU(alpha=0.2)(model)
    return model

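# Generator: 9x9 input convolution, 16 residual blocks, a long skip
# connection, two 2x upsampling blocks (64x64 -> 256x256), and a 9x9
# convolution with tanh output.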
def generator_network():
    gen_input = Input(shape=image_lr_shape)
    model = Conv2D(filters=64,kernel_size=9,strides=1,padding="same")(gen_input)
    model = PReLU(alpha_initializer="zeros",alpha_regularizer=None,alpha_constraint=None,shared_axes=[1,2])(model)
    gen_model = model
    for index in range(16):
        model = residual_block(model,3,64,1)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(model)
    model = BatchNormalization(momentum=0.5)(model)
    model = add([gen_model,model])
    for index in range(2):
        model = up_sampling_block(model,3,256,1)
    model = Conv2D(filters=3,kernel_size=9,strides=1,padding="same")(model)
    model = Activation("tanh")(model)
    generator_model = Model(inputs=gen_input,outputs=model)
    return generator_model

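# Discriminator: stack of Conv-BN-LeakyReLU blocks with increasing filter
# counts and alternating strides, then Dense(1024), LeakyReLU, and a
# sigmoid real/fake output.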
def discriminator_network():
    dis_input = Input(shape=image_hr_shape)
    model = Conv2D(filters=64,kernel_size=3,strides=1,padding="same")(dis_input)
    model = LeakyReLU(alpha=0.2)(model)
    model = discriminator_block(model,64,3,2)
    model = discriminator_block(model,128,3,1)
    model = discriminator_block(model,128,3,2)
    model = discriminator_block(model,256,3,1)
    model = discriminator_block(model,256,3,2)
    model = discriminator_block(model,512,3,1)
    model = discriminator_block(model,512,3,2)
    model = Flatten()(model)
    model = Dense(1024)(model)
    model = LeakyReLU(alpha=0.2)(model)
    model = Dense(1)(model)
    model = Activation("sigmoid")(model)
    discriminator_model = Model(inputs=dis_input,outputs=model)
    return discriminator_model

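# Combined model: LR input -> generator -> (frozen) discriminator. Trained
# with the VGG perceptual loss on the generated image plus a 1e-3-weighted
# adversarial binary cross-entropy term.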
def get_gan_network(discriminator,generator,optimizer):
    discriminator.trainable=False
    gan_input = Input(shape=image_lr_shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input,outputs=[x,gan_output])
    gan.compile(loss=[vgg_loss,"binary_crossentropy"],loss_weights=[1.,1e-3], optimizer=optimizer)
    return gan

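# Read and decode one PNG, converting to float32 in [0, 1].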
def load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img,channels=3)
    img = tf.image.convert_image_dtype(img,tf.float32)
    return img

#============================================WARNING: Dataset size hardcoded!
print("Connecting to TPU...")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="******",zone="******",project="******")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
    print("Loading dataset...")
    # Same shuffle seed in both pipelines keeps the LR and HR batches paired.
    image_hr = (tf.data.Dataset.list_files(image_hr_dir, shuffle=False).map(load_image)
                .shuffle(6777, seed=12345679, reshuffle_each_iteration=True)
                .batch(batch_size).prefetch(4))
    image_lr = (tf.data.Dataset.list_files(image_lr_dir, shuffle=False).map(load_image)
                .shuffle(6777, seed=12345679, reshuffle_each_iteration=True)
                .batch(batch_size).prefetch(4))
    batch_count = int(6777/batch_size)
    print("Compiling neural network---")
    generator = generator_network()
    discriminator = discriminator_network()
    adam = Adam(learning_rate=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-8)
    generator.compile(loss=vgg_loss,optimizer=adam)
    discriminator.compile(loss="binary_crossentropy",optimizer=adam)
    gan = get_gan_network(discriminator,generator,adam)
    print("Starting training...")
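    # Alternating update: train the discriminator on one real HR batch and one
    # generated SR batch (with smoothed labels), then train the generator
    # through the combined model with the discriminator frozen.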
    for e in range(1,epochs+1):
        print("-"*15,"Epoch %d" % e, "-"*15)
        for b in range(batch_count):
            print("-"*10,"Batch %d" % b,"-"*10)
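            # Fetch one HR and one LR batch; note that every .take(1) call
            # builds and iterates a new input pipeline.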
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            batch_sr = generator.predict(batch_lr)
            real_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            fake_y = tf.random.uniform(shape=(batch_size,1),minval=0,maxval=0.2)
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(batch_hr,real_y)
            d_loss_fake = discriminator.train_on_batch(batch_sr,fake_y)
            batch_hr = np.stack(tfds.as_numpy(image_hr.take(1)))
            batch_lr = np.stack(tfds.as_numpy(image_lr.take(1)))
            gan_y = tf.random.uniform(shape=(batch_size,1),minval=0.8,maxval=1)
            discriminator.trainable = False
            loss_gan = gan.train_on_batch(batch_lr,[batch_hr,gan_y])
        print("Loss d_real,Loss d_fake, Loss network")
        print(d_loss_real,d_loss_fake,loss_gan)
        tf.io.gfile.makedirs(os.path.join(model_dir, str(e)))  # gfile handles gs:// paths
        generator.save(os.path.join(model_dir,str(e),"gen.h5"))
        discriminator.save(os.path.join(model_dir,str(e),"dis.h5"))
        gan.save(os.path.join(model_dir,str(e),"gan.h5"))

Update: I added more debug prints and it stalls at the batch_hr assignment. There is no error, it just hangs there.
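
I suspect every image_hr.take(1) builds and iterates a brand-new input pipeline (re-listing and re-shuffling the files on GCS) instead of advancing an existing iterator. A minimal sketch of what I assume is the usual alternative, keeping the same image_hr/image_lr definitions as above, would be to create one persistent iterator per dataset before the loop and pull batches with next():

# Hypothetical rework (untested sketch, not my original code): build each
# pipeline once, keep a persistent iterator, and advance it one batch at a
# time so the file listing and shuffle are not redone for every batch.
hr_batches = iter(tfds.as_numpy(image_hr.repeat()))  # repeat() so the iterator never runs out
lr_batches = iter(tfds.as_numpy(image_lr.repeat()))

for b in range(batch_count):
    batch_hr = next(hr_batches)  # a (batch_size, 256, 256, 3) numpy array
    batch_lr = next(lr_batches)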

0 Answers:

No answers yet.