Discriminator gradients are all zero when training a conditional GAN in TensorFlow

Asked: 2020-05-14 06:01:34

Tags: python tensorflow keras generative-adversarial-network dcgan

[Update: I am no longer getting all-zero gradients for the discriminator. There may be a problem with my architecture or with the initialization of the layers. I will try to fix it.]

I am trying to train a conditional GAN in TensorFlow for image synthesis from text captions. Here is the original paper: http://arxiv.org/abs/1605.05396

The problem I am facing is that the gradients I get for the discriminator's parameters are all zero. I have not been able to trace the problem back from there, because the discriminator loss is positive, and the gradients are computed with the predefined tf.GradientTape.gradient(discriminator_loss, discriminator_variables) call.
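
For reference, this is roughly how I verify that the gradients really are zero (a minimal standalone sketch with a toy dense model, not my real networks):

import tensorflow as tf

toy_model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1)
])
x = tf.random.normal((16, 8))
y = tf.ones((16, 1))
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

with tf.GradientTape() as tape:
    loss = bce(y, toy_model(x, training=True))

grads = tape.gradient(loss, toy_model.trainable_variables)
for var, grad in zip(toy_model.trainable_variables, grads):
    # grad is None when the variable is not connected to the loss at all;
    # a zero norm means it is connected but the gradient vanished.
    print(var.name, None if grad is None else tf.norm(grad).numpy())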

Also, I have already implemented this in PyTorch and did not face this problem there. The syntax for computing gradients is somewhat different between TensorFlow and PyTorch, so I suspect the problem lies in my understanding of TensorFlow rather than in the Generator and Discriminator architectures, but I may be wrong.
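
For context, this is the general shape of the two update patterns as I understand them (compute_d_loss and d_optimizer are placeholders, not my real functions):

import tensorflow as tf

# PyTorch version (for comparison):
#     d_optimizer.zero_grad()
#     d_loss = compute_d_loss(...)   # forward pass can happen anywhere
#     d_loss.backward()              # gradients accumulate on the parameters
#     d_optimizer.step()

# TensorFlow 2 version: the forward pass and the loss must be computed inside
# the GradientTape context, otherwise tape.gradient returns None for those variables.
def tf_discriminator_update(discriminator, d_optimizer, compute_d_loss):
    with tf.GradientTape() as tape:
        d_loss = compute_d_loss()  # forward pass under the tape
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    return d_loss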

I am pasting the important parts of the code below; I would appreciate it if someone could help me figure out the problem.

Please let me know if I should post more details or remove some of the clutter.

My guess is that something is wrong with the train_step function, but I have also included the Generator and Discriminator architectures below (even though that may be too much code to read through).

import tensorflow as tf
from tensorflow.keras import layers

generator = Generator()
discriminator = Discriminator()
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)

generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-2, beta_1=0.5, beta_2=0.999)

def discriminator_loss(criterion, real_output, fake_output, wrong_caption_output):

    real_labels = tf.ones_like(real_output)    
    fake_labels = tf.zeros_like(fake_output)

    real_loss = criterion(real_labels, real_output)
    fake_loss = criterion(fake_labels, fake_output) + criterion(fake_labels, wrong_caption_output)

    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(criterion, fake_output):
    real_labels = tf.ones_like(fake_output)
    loss = criterion(real_labels, fake_output)
    return loss

@tf.function
def train_step(right_images, right_embed, wrong_images, wrong_embed, noise):

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator.forward(right_embed, noise, training=True) #(B, C, 64, 64)

        real_output, _ = discriminator.forward(right_embed, right_images, training = True)
        fake_output, _ = discriminator.forward(right_embed, tf.stop_gradient(generated_images), training = True)  # stop_gradient so the disc update does not backprop into the generator

        wrong_caption_output, _ = discriminator.forward(wrong_embed, right_images, training = True)

        # Disc Losses
        disc_loss = discriminator_loss(criterion, real_output, fake_output, wrong_caption_output)


        ## Pass the generated images through the discriminator again (no stop_gradient here) for the generator loss
        fake_output_1, _ = discriminator.forward(right_embed, generated_images, training = True)

        # Gen loss
        gen_loss = generator_loss(criterion, fake_output_1)

    # Train Disc
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Train Gen
    gen_variables = generator.trainable_variables

    gradients_of_generator = gen_tape.gradient(gen_loss, gen_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_variables))

    return gen_loss, disc_loss, real_output, fake_output, gradients_of_generator, gradients_of_discriminator

Architectures of the Generator and Discriminator:


class Generator(tf.Module):
    def __init__(self):
        super().__init__()

        w_init = tf.random_normal_initializer(stddev=0.02)

        gamma_init = tf.random_normal_initializer(1., 0.02)

        model = tf.keras.Sequential()
        model.add(layers.Dense(projected_embedding_size, input_shape = (embedding_size, ),
                              kernel_initializer = w_init))
        model.add(layers.BatchNormalization(epsilon = 1e-5, 
                                            gamma_initializer = gamma_init))
        model.add(layers.ReLU()) #(B, projected_embedding_size)

        self.projection = model

        model_1 = tf.keras.Sequential()

        model_1.add(layers.Conv2DTranspose(filters = ngf*8, input_shape = (latent_dim, 1, 1),
                                            kernel_size = 4, kernel_initializer = w_init,
                                         strides= 1, padding = 'valid', 
                                              data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) #(B, ngf*8, 4, 4)


        model_1.add(layers.Conv2DTranspose(filters = ngf*4, kernel_size = 4,
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf*4, 8, 8)


        model_1.add(layers.Conv2DTranspose(filters = ngf*2, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf*2, 16, 16)


        model_1.add(layers.Conv2DTranspose(filters = ngf, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf, 32, 32)


        model_1.add(layers.Conv2DTranspose(filters = img_channels, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.Activation('tanh')) # (B, img_channels, 64, 64)

        self.netG = model_1

    def forward(self, embedding, noise, training = True):
        projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
        noise = noise # (B, noise_dim)

        input = tf.keras.backend.concatenate((noise, projected_embedding), axis = 1) #(B, projected_embedding_size + noise_dim)
        input = tf.keras.backend.reshape(input, shape=(input.shape[0], input.shape[1], 1, 1))
        output = self.netG(input, training = training) # (B, img_channels, 64, 64)

        return output

class Discriminator(tf.Module):
    def __init__(self):
        super().__init__()

        w_init = tf.random_normal_initializer(stddev=0.02)

        gamma_init = tf.random_normal_initializer(1., 0.02)
        model = tf.keras.Sequential() #(B, 64, 64, img_channels)

        model.add(layers.Conv2D(filters = ndf, 
                                input_shape = (generated_img_size, generated_img_size, img_channels),
                                kernel_size = 4,kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 32, 32, ndf)


        model.add(layers.Conv2D(filters = ndf*2, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 16, 16, ndf*2)


        model.add(layers.Conv2D(filters = ndf*4, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 8, 8, ndf*4)


        model.add(layers.Conv2D(filters = ndf*8, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 4, 4, ndf*8)

        self.netD_1 = model

        # Projection model
        model_1 = tf.keras.Sequential()
        model_1.add(layers.Dense(projected_embedding_size, input_shape=(embedding_size,),
                                kernel_initializer = w_init))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.LeakyReLU(0.2))

        self.projection = model_1

        # Discriminator model 2 - Combining with captions embedding
        model_2 = tf.keras.Sequential()
        model_2.add(layers.Conv2D(filters = 1, input_shape = (4, 4, ndf *8 + projected_embedding_size),
                                            kernel_size = 4, kernel_initializer = w_init,
                                strides= 1, padding = 'valid', 
                                              data_format='channels_last', use_bias = False))
        model_2.add(layers.Activation('sigmoid'))

        self.netD_2 = model_2

    def forward(self, embedding, input, training = True):
        projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
        projected_embedding = tf.keras.backend.reshape(projected_embedding, 
                                                       shape = (1, 1, projected_embedding.shape[0], 
                                                                projected_embedding.shape[1]))
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 0)
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 1)

        projected_embedding = projected_embedding  # (4, 4, B, projected_embedding_size)
        projected_embedding = tf.keras.backend.permute_dimensions(projected_embedding, pattern = (2, 0, 1, 3)) # (B, 4, 4, projected_embedding_size)

        # input = (B, C, 64, 64)
        input = tf.keras.backend.permute_dimensions(input, pattern = (0, 2, 3, 1)) # (B, 64, 64, C)

        x_intermediate = self.netD_1(input, training = training) # (B, 4, 4, ndf*8)

        output = tf.keras.backend.concatenate((x_intermediate, projected_embedding), axis = 3)  # (B, 4, 4, ndf*8 + projected_embedding_size)

        output = self.netD_2(output, training = training) # (B, 1, 1, 1)
        output = tf.keras.backend.reshape(output, shape= (output.shape[0], ))

        return output, x_intermediate

0 Answers:

No answers yet.