MNIST DC-GAN所有梯度均为零

时间:2020-07-15 17:06:36

标签: python mnist generative-adversarial-network jax

我正在尝试使用Flax作为指南,在TF example框架上为MNIST构建DC-GAN。网络本身在技术上是可行的,但是生成器和鉴别器均不会更新,因为它们的梯度始终为零。我已经确定权重已正确初始化,并尝试提高学习率,但这没有帮助。我只能怀疑问题出在网络本身的体系结构中,但是它是逐行复制的,除了生成器中的BatchNormalizations和鉴别器中的Dropouts之外。

class generator_class(nn.Module):
    def apply(self, x):
        x = nn.Dense(x, features=7*7*256, bias_init=initializers.zeros)
        x = lrelu(x)
        x = x.reshape((-1, 7, 7, 256))
        x = nn.ConvTranspose(x, features=128, kernel_size=(5, 5), strides=(1, 1), bias=False)
        x = lrelu(x)
        x = nn.ConvTranspose(x, features=64, kernel_size=(5, 5), strides=(2, 2), bias=False)
        x = lrelu(x)
        x = nn.ConvTranspose(x, features=1, kernel_size=(5, 5), strides=(2, 2), bias=False)
        x = nn.tanh(x)
    return x

class discriminator_class(nn.Module):
    def apply(self, x):
        x = nn.Conv(x, features=64, kernel_size=(5, 5), strides=(2,2))
        x = lrelu(x)
        x = nn.Conv(x, features=128, kernel_size=(5, 5), strides=(2,2))
        x = lrelu(x)
        x = x.reshape((x.shape[0], -1)) #flatten
        x = nn.Dense(x, features=1)
    return x

_, init_params = generator_class.init_by_shape(random.PRNGKey(0), [((100,), jnp.float32)])
generator = nn.Model(generator_class, init_params)

_, init_params = discriminator_class.init_by_shape(random.PRNGKey(0), [((1, 28, 28, 1), jnp.float32)])
discriminator = nn.Model(discriminator_class, init_params)

@jax.vmap
def binary_cross_entropy(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit
def train_step(generator_optimizer, discriminator_optimizer, images):

    noise = jax.random.normal(random.PRNGKey(0), shape = [256, 100])
    generated_images = gen(noise)
    real_output = disc(images)
    fake_output = disc(generated_images)

    def generator_loss(generator):
        return binary_cross_entropy(jnp.ones_like(fake_output), fake_output).mean(), generated_images

    def discriminator_loss(discriminator):
        real_loss = binary_cross_entropy(jnp.ones_like(real_output), real_output).mean()
        fake_loss = binary_cross_entropy(jnp.zeros_like(fake_output), fake_output).mean()
        total_loss = real_loss + fake_loss
        return total_loss, fake_output

    grad_fn_gen = jax.value_and_grad(generator_loss, has_aux=True)
    (_, preds), grad_gen = grad_fn_gen(generator_optimizer.target)
    generator_optimizer = generator_optimizer.apply_gradient(grad_gen)    

    grad_fn_disc = jax.value_and_grad(discriminator_loss, has_aux=True)
    (_, preds), grad_disc = grad_fn_disc(discriminator_optimizer.target)
    discriminator_optimizer = discriminator_optimizer.apply_gradient(grad_disc)

return generator_optimizer, discriminator_optimizer

0 个答案:

没有答案