[更新:我现在没有得到鉴别器的所有零梯度。我的体系结构或各层的初始化可能存在一些问题。我会尝试修复它。]
我正在尝试在TensorFlow中训练条件GAN以使用字幕进行图像合成。这是原始论文:http://arxiv.org/abs/1605.05396
我面临的问题是我为鉴别器的参数获得的梯度全为零。从那时起,我无法回溯问题,因为鉴频器损耗为正,并且梯度是使用预定义函数tf.gradient(discriminator_loss, discriminator_variables)
此外,我也已经在PyTorch中做到了这一点,但是我在那里没有遇到问题,因为TensorFlow和PyTorch中计算梯度的语法有些不同。因此,我认为问题出在我对TensorFlow的理解中,而不是与Generator和Discriminator的体系结构有关,但我可能是错的。
我要粘贴以下代码的重要部分,如果有谁可以帮助我解决问题。
请让我知道我是否应该发布更多详细信息或删除一些混乱情况。
我的猜测是train_step函数出了点问题,但是我还在下面包括了Generator和Discriminator的体系结构(尽管这可能太多代码无法读取)。
generator = Generator()
discriminator = Discriminator()
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-2, beta_1=0.5, beta_2=0.999)
def discriminator_loss(criterion, real_output, fake_output, wrong_caption_output):
real_labels = tf.ones_like(real_output)
fake_labels = tf.zeros_like(fake_output)
real_loss = criterion(real_labels, real_output)
fake_loss = criterion(fake_labels, fake_output) + criterion(fake_labels, wrong_caption_output)
total_loss = real_loss + fake_loss
return total_loss
def generator_loss(criterion, fake_output):
real_labels = tf.ones_like(fake_output)
loss = criterion(real_labels, fake_output)
return loss
@tf.function
def train_step(right_images, right_embed, wrong_images, wrong_embed, noise):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator.forward(right_embed, noise, training=True) #(B, C, 64, 64)
real_output, _ = discriminator.forward(right_embed, right_images, training = True)
fake_output, _ = discriminator.forward(right_embed, tf.stop_gradient(generated_images), training = True)
wrong_caption_output, _ = discriminator.forward(wrong_embed, right_images, training = True)
# Disc Losses
disc_loss = discriminator_loss(criterion, real_output, fake_output, wrong_caption_output)
## Pass generated images through the trained disc
fake_output_1, _ = discriminator.forward(right_embed, generated_images, training = True)
# Gen loss
gen_loss = generator_loss(criterion, fake_output_1)
# Train Disc
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
# Train Gen
gen_variables = generator.trainable_variables
gradients_of_generator = gen_tape.gradient(gen_loss, gen_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_variables))
return gen_loss, disc_loss, real_output, fake_output, gradients_of_generator, gradients_of_discriminator
生成器和鉴别器的体系结构:
class Generator(tf.Module):
def __init__(self):
super().__init__()
w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)
model = tf.keras.Sequential()
model.add(layers.Dense(projected_embedding_size, input_shape = (embedding_size, ),
kernel_initializer = w_init))
model.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model.add(layers.ReLU()) #(B, projected_embedding_size)
self.projection = model
model_1 = tf.keras.Sequential()
model_1.add(layers.Conv2DTranspose(filters = ngf*8, input_shape = (latent_dim, 1, 1),
kernel_size = 4, kernel_initializer = w_init,
strides= 1, padding = 'valid',
data_format='channels_first', use_bias = False))
model_1.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model_1.add(layers.ReLU()) #(B, ngf*8, 4, 4)
model_1.add(layers.Conv2DTranspose(filters = ngf*4, kernel_size = 4,
strides= 2, kernel_initializer = w_init,
padding = 'same', data_format='channels_first', use_bias = False))
model_1.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model_1.add(layers.ReLU()) # (B, ngf*8, 8, 8)
model_1.add(layers.Conv2DTranspose(filters = ngf*2, kernel_size = 4,
strides= 2, kernel_initializer = w_init,
padding = 'same', data_format='channels_first', use_bias = False))
model_1.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model_1.add(layers.ReLU()) # (B, ngf*2, 16, 16)
model_1.add(layers.Conv2DTranspose(filters = ngf, kernel_size = 4,
strides= 2, kernel_initializer = w_init,
padding = 'same', data_format='channels_first', use_bias = False))
model_1.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model_1.add(layers.ReLU()) # (B, ngf, 32, 32)
model_1.add(layers.Conv2DTranspose(filters = img_channels, kernel_size = 4,
strides= 2, kernel_initializer = w_init,
padding = 'same', data_format='channels_first', use_bias = False))
model_1.add(layers.Activation('tanh')) # (B, img_channels, 64, 64)
self.netG = model_1
def forward(self, embedding, noise, training = True):
projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
noise = noise # (B, noise_dim)
input = tf.keras.backend.concatenate((noise, projected_embedding), axis = 1) #(B, projected_embedding_size + noise_dim)
input = tf.keras.backend.reshape(input, shape=(input.shape[0], input.shape[1], 1, 1))
output = self.netG(input, training = training) # (B, img_channels, 64, 64)
return output
class Discriminator(tf.Module):
def __init__(self):
super().__init__()
w_init = tf.random_normal_initializer(stddev=0.02)
gamma_init = tf.random_normal_initializer(1., 0.02)
model = tf.keras.Sequential() #(B, 64, 64, img_channels)
model.add(layers.Conv2D(filters = ndf,
input_shape = (generated_img_size, generated_img_size, img_channels),
kernel_size = 4,kernel_initializer = w_init,
strides= 2, padding = 'same',
data_format='channels_last', use_bias = False))
model.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model.add(layers.LeakyReLU(0.2)) #(B, 32, 32, ndf)
model.add(layers.Conv2D(filters = ndf*2, kernel_size = 4, kernel_initializer = w_init,
strides= 2, padding = 'same',
data_format='channels_last', use_bias = False))
model.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model.add(layers.LeakyReLU(0.2)) #(B, 16, 16, ndf*2)
model.add(layers.Conv2D(filters = ndf*4, kernel_size = 4, kernel_initializer = w_init,
strides= 2, padding = 'same',
data_format='channels_last', use_bias = False))
model.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model.add(layers.LeakyReLU(0.2)) #(B, 8, 8, ndf*4)
model.add(layers.Conv2D(filters = ndf*8, kernel_size = 4, kernel_initializer = w_init,
strides= 2, padding = 'same',
data_format='channels_last', use_bias = False))
model.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model.add(layers.LeakyReLU(0.2)) #(B, 4, 4, ndf*8)
self.netD_1 = model
# Projection model
model_1 = tf.keras.Sequential()
model_1.add(layers.Dense(projected_embedding_size, input_shape=(embedding_size,),
kernel_initializer = w_init))
model_1.add(layers.BatchNormalization(epsilon = 1e-5,
gamma_initializer = gamma_init))
model_1.add(layers.LeakyReLU(0.2))
self.projection = model_1
# Discriminator model 2 - Combining with captions embedding
model_2 = tf.keras.Sequential()
model_2.add(layers.Conv2D(filters = 1, input_shape = (4, 4, ndf *8 + projected_embedding_size),
kernel_size = 4, kernel_initializer = w_init,
strides= 1, padding = 'valid',
data_format='channels_last', use_bias = False))
model_2.add(layers.Activation('sigmoid'))
self.netD_2 = model_2
def forward(self, embedding, input, training = True):
projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
projected_embedding = tf.keras.backend.reshape(projected_embedding,
shape = (1, 1, projected_embedding.shape[0],
projected_embedding.shape[1]))
projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 0)
projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 1)
projected_embedding = projected_embedding # (4, 4, B, projected_embedding_size)
projected_embedding = tf.keras.backend.permute_dimensions(projected_embedding, pattern = (2, 0, 1, 3)) # (B, 4, 4, projected_embedding_size)
# input = (B, C, 64, 64)
input = tf.keras.backend.permute_dimensions(input, pattern = (0, 2, 3, 1)) # (B, 64, 64, C)
x_intermediate = self.netD_1(input, training = training) # (B, 4, 4, ndf*8)
output = tf.keras.backend.concatenate((x_intermediate, projected_embedding), axis = 3) # (B, 4, 4, ndf*8 + projected_embedding_size)
output = self.netD_2(output, training = training) # (B, 1, 1, 1)
output = tf.keras.backend.reshape(output, shape= (output.shape[0], ))
return output, x_intermediate