Wasserstein GAN: problems with the final discriminator layer and weight clipping

Date: 2019-12-28 17:03:57

Tags: machine-learning keras computer-vision conv-neural-network gan

When I use a linear activation (or no activation) in the final discriminator layer together with weight clipping, the discriminator accuracy goes to 1 and the generator accuracy goes to 0. If I remove the weight clipping, the opposite happens: the generator accuracy goes to 1 and the discriminator accuracy goes to 0 within about 300 iterations. However, when I use a sigmoid activation in the final discriminator layer, with clipping the generator accuracy goes to 1, while without clipping the generator loss gets stuck and its accuracy hovers around 0.5.

Note: in every case the results are produced while the following warning is shown: `tensorflow: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after?`
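
For reference, the compile ordering usually recommended in Keras GAN examples for this warning is: compile the critic while it is trainable, freeze it, compile the stacked model, and never recompile the critic while it is frozen. A minimal sketch of that pattern (assuming Keras 2.x; the dense models and shapes are placeholders, not my actual architecture):

from keras.models import Model
from keras.layers import Input, Dense
from keras.optimizers import RMSprop
from keras import backend as K

def wasserstein(y_true, y_pred):
    return -K.mean(y_true * y_pred)

z_dim = 100

# Critic with a linear output, compiled while it is still trainable.
critic_in = Input(shape=(784,))
critic = Model(critic_in, Dense(1)(critic_in))
critic.compile(optimizer=RMSprop(lr=0.00005), loss=wasserstein)

gen_in = Input(shape=(z_dim,))
generator = Model(gen_in, Dense(784, activation='tanh')(gen_in))

# Freeze the critic only for the stacked model; do NOT recompile the critic.
critic.trainable = False
gan_in = Input(shape=(z_dim,))
gan = Model(gan_in, critic(generator(gan_in)))
gan.compile(optimizer=RMSprop(lr=0.00005), loss=wasserstein)

# critic.train_on_batch(...) still updates the critic, because its train
# function was built while it was trainable; gan.train_on_batch(...) only
# updates the generator.

In some Keras versions this pattern still prints the discrepancy warning, but it appears to be benign here because each model trains with the trainable flags captured at its own compile time.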

The code is given below (please excuse any copy-and-paste indentation issues):

import numpy as np
from keras.models import Model
from keras.layers import (Input, Dense, Flatten, Reshape, Dropout, Activation,
                          Conv2D, Conv2DTranspose, UpSampling2D,
                          BatchNormalization, LeakyReLU)
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras import backend as K


class WGAN():

  def __init__(self,
              input_dim,
              disc_filter,
              disc_kernel,
              disc_strides,
              disc_dropout,
              disc_lr,
              gen_filter,
              gen_kernel,
              gen_strides,
              gen_upsample,
              gen_lr,
              z_dim,
              batch_size):


      self.input_dim = input_dim
      self.disc_filter = disc_filter
      self.disc_kernel = disc_kernel
      self.disc_strides = disc_strides
      self.disc_dropout = disc_dropout
      self.disc_lr = disc_lr
      self.gen_filter = gen_filter
      self.gen_kernel = gen_kernel
      self.gen_strides = gen_strides
      self.gen_upsample = gen_upsample
      self.gen_lr = gen_lr
      self.z_dim = z_dim
      self.batch_size = batch_size
      self.weight_init = RandomNormal(mean=0., stddev=0.02)

      self.d_losses = []
      self.g_losses = []
      self.epoch = 0

      self.Discriminator()
      self.Generator()

      self.full_model()

  # Wasserstein loss: with labels in {+1, -1}, minimising this drives the
  # critic scores apart for real and generated samples.
  def wasserstein(self, y_true, y_pred):
    return -K.mean(y_true * y_pred)

  def Discriminator(self):

    disc_input = Input(shape=self.input_dim, name='discriminator_input')
    x = disc_input

    for i in range(len(self.disc_filter)):
        x = Conv2D(filters=self.disc_filter[i], kernel_size=self.disc_kernel[i], strides=self.disc_strides[i], padding='same', name='disc_'+str(i))(x)
        x = LeakyReLU()(x)
        x = Dropout(self.disc_dropout)(x)
        x = BatchNormalization()(x)

    x = Flatten()(x)
    # This is the final layer whose activation I have been varying
    # (sigmoid here; linear / no activation in the other experiments).
    disc_output = Dense(1, activation='sigmoid', kernel_initializer=self.weight_init)(x)
    self.discriminator = Model(disc_input, disc_output)

  def Generator(self):

    gen_input = Input(shape=(self.z_dim,), name='generator_input')
    x = gen_input

    # Note: batch_size is used here as the number of feature maps of the
    # initial 7x7 block, so the dense layer outputs 7*7*batch_size units.
    x = Dense(7*7*self.batch_size, kernel_initializer = self.weight_init)(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = Reshape(target_shape=(7,7,self.batch_size))(x)

    for i in range(len(self.gen_filter)):
        if self.gen_upsample[i]==2:
            x = UpSampling2D(size=self.gen_upsample[i], name='upsample_'+str(i))(x)
            x = Conv2D(filters=self.gen_filter[i], kernel_size=self.gen_kernel[i], strides=self.gen_strides[i], padding='same', name='gen_'+str(i))(x)

        else:
            x = Conv2DTranspose(filters=self.gen_filter[i], kernel_size=self.gen_kernel[i], strides=self.gen_strides[i], padding='same', name='gen_'+str(i))(x)

        if i<len(self.gen_filter)-1:
            x = BatchNormalization()(x)
            x = LeakyReLU()(x)

        else:
            x = Activation("tanh")(x)

    gen_output = x
    self.generator = Model(gen_input, gen_output)


  def set_trainable(self, model, val):
    model.trainable=val
    for l in model.layers:
        l.trainable=val

  def full_model(self):

    ### COMPILE DISCRIMINATOR (while it is trainable, so that its own
    ### train function actually updates the critic weights)
    self.discriminator.compile(optimizer= Adam(self.disc_lr), loss = self.wasserstein, metrics=['accuracy'])

    ### COMPILE THE FULL GAN with the discriminator frozen, so that training
    ### the stacked model only updates the generator. Do not recompile the
    ### discriminator here: recompiling while it is frozen would leave it
    ### with no trainable weights to update.
    self.set_trainable(self.discriminator, False)

    model_input = Input(shape=(self.z_dim,), name='model_input')
    model_output = self.discriminator(self.generator(model_input))
    self.model = Model(model_input, model_output)

    self.model.compile(optimizer= Adam(self.disc_lr), loss = self.wasserstein, metrics=['accuracy'])

    # Restoring the flag without recompiling is what triggers the
    # "Discrepancy between trainable weights and collected trainable weights"
    # warning; each model still trains with the flags captured at its compile.
    self.set_trainable(self.discriminator, True)

  def train_generator(self, batch_size):
    valid = np.ones((batch_size,1))
    noise = np.random.normal(0, 1, (batch_size, self.z_dim))
    return self.model.train_on_batch(noise, valid)

  def train_discriminator(self, x_train, batch_size, using_generator):
    valid = np.ones((batch_size,1))
    # WGAN convention: generated samples are labelled -1, not 0 (with the
    # wasserstein loss above, zero labels make the fake loss identically zero).
    fake = -np.ones((batch_size,1))

    if using_generator:
        true_imgs = next(x_train)[0]
        if true_imgs.shape[0] != batch_size:
            true_imgs = next(x_train)[0]
    else:
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        true_imgs = x_train[idx]

    noise = np.random.normal(0, 1, (batch_size, self.z_dim))
    gen_imgs = self.generator.predict(noise)

    d_loss_real, d_acc_real =  self.discriminator.train_on_batch(true_imgs, valid)
    d_loss_fake, d_acc_fake =  self.discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_acc = 0.5 * (d_acc_real + d_acc_fake)

    # Weight clipping (Lipschitz constraint). Note that this clips every
    # weight array of every layer, including BatchNormalization's moving
    # mean and variance.
    for l in self.discriminator.layers:
        weights = l.get_weights()
        weights = [np.clip(w, -0.01, 0.01) for w in weights]
        l.set_weights(weights)

    return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]

  def train(self, x_train, batch_size, epochs, print_every_n_batches = 50, using_generator = False):

    for epoch in range(self.epoch, self.epoch + epochs):

        d = self.train_discriminator(x_train, batch_size, using_generator)
        g = self.train_generator(batch_size)
        if self.epoch % print_every_n_batches == 0:
            print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1]))

        self.d_losses.append(d)
        self.g_losses.append(g)

        self.epoch+=1
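
For comparison, my understanding of the conventions from the original WGAN paper (Arjovsky et al., 2017) is: a linear critic output (no sigmoid), labels of +1 and -1, several critic updates per generator update with clipping after each one, and RMSprop with a small learning rate. A rough sketch of one such training step, reusing the class above (the helper name and the n_critic and clip values are illustrative, not from my code):

import numpy as np

def wgan_train_step(wgan, x_train, batch_size, n_critic=5, clip=0.01):
    # One WGAN step: n_critic critic updates, then one generator update.
    # Assumes `wgan` is a WGAN instance as defined above, with a linear
    # final critic layer and the wasserstein loss.
    valid = np.ones((batch_size, 1))
    fake = -np.ones((batch_size, 1))  # -1, not 0, for the wasserstein loss

    for _ in range(n_critic):
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        noise = np.random.normal(0, 1, (batch_size, wgan.z_dim))
        gen_imgs = wgan.generator.predict(noise)

        wgan.discriminator.train_on_batch(x_train[idx], valid)
        wgan.discriminator.train_on_batch(gen_imgs, fake)

        # Clip after every critic update to enforce the Lipschitz constraint.
        for layer in wgan.discriminator.layers:
            layer.set_weights([np.clip(w, -clip, clip) for w in layer.get_weights()])

    noise = np.random.normal(0, 1, (batch_size, wgan.z_dim))
    return wgan.model.train_on_batch(noise, valid)

Note that binary accuracy is not very meaningful for an unbounded linear critic, so the accuracies printed in train() may be misleading in the linear-activation experiments.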

0 Answers:

No answers yet.