Keras GAN (generator) trains poorly despite an accurate discriminator

Time: 2018-02-13 10:22:58

Tags: tensorflow keras

I've been trying to sort this out for several days now and have worked through plenty of suggestions from forums etc., so any advice on what's going wrong would be very welcome!

I'm attempting my first GAN training: a simple feed-forward deep network, very similar to the usual MNIST examples, but using spectral power windows of size (1, 513) derived from the VCTK-Corpus.

You can see from the Tensorboard graphs below that the networks seem to be interacting and that some sort of training is taking place: Tensorboard graph overview / Tensorboard graph zoom

However, the results are poor and noisy: generated and validation comparison

The generator takes normal (Gaussian) noise (I've tried vector lengths of 30 to 100) with mean 0 and stdev 0.5.
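noisegen is called in the training loop further down but never shown; a minimal sketch consistent with the description above (Gaussian noise, mean 0, stdev 0.5) might look like the following. The name and call signature are taken from the question's code; the body is an assumption.

import numpy as np

def noisegen(shape, mu=0.0, sigma=0.5):
    #Gaussian noise with mean 0 and stdev 0.5, as described above
    return np.random.normal(mu, sigma, size=shape)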

from keras.layers import Input, Dense, Dropout, BatchNormalization, LeakyReLU, Reshape
from keras.models import Model

def gan_generator(x_shape, frame_size):
    g_input = Input(shape=x_shape)
    H = BatchNormalization()(g_input)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    out = Dense(frame_size[1], activation='linear')(H)

    generator = Model(g_input, out)
    generator.summary()
    return generator

The discriminator produces a one-hot classification of the generated frames. (I'm not sure about the batch normalisation here: I've read that it shouldn't be used if you mix real and generated samples into a single batch. However, the generator produces far more convincing results with it than without, despite a higher loss.)

def gan_discriminator(input_shape):
    d_input = Input(shape=input_shape)
    H = Dropout(0.1)(d_input)
    H = Dense(256)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Reshape((1, -1))(H)
    d_V = Dense(2, activation='softmax')(H)

    discriminator = Model(d_input,d_V)
    discriminator.summary()
    return discriminator
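
Regarding the batch normalisation caveat above: one commonly suggested workaround (not what the code here does) is to update the discriminator on real and generated frames in two separate batches, so the batch-norm statistics are never computed over a mixed batch. A rough sketch, reusing the variable names from the training loop further down:

#Hypothetical alternative to concatenating real and fake frames into one batch:
#two separate updates per step, so batch norm never sees mixed statistics
d_loss_real = discriminator.train_on_batch(frames, frames_label)
d_loss_fake = discriminator.train_on_batch(generated_frames, generated_label)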

The GAN itself is straightforward:

def init_gan(generator, discriminator):
    x = Input(shape=generator.inputs[0].shape[1:])

    #Generator makes a prediction
    pred = generator(x)

    #Discriminator attempts to categorise prediction
    y = discriminator(pred)

    GAN = Model(x, y)
    return GAN
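
For reference, with the frame size described at the top ((1, 513)) and a noise vector of length n_noise, the three builders above get wired together roughly like this; the exact shapes are my assumption, inferred from the shapes used in the training loop below.

n_noise = 100              #noise vector length (30 to 100 per the description above)
frame_size = (1, 513)      #spectral power window size

generator = gan_generator((1, n_noise), frame_size)
discriminator = gan_discriminator(frame_size)
GAN = init_gan(generator, discriminator)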

A few training variables:

  • GAN (generator): Adam, lr = 1e-4, categorical_crossentropy
  • Discriminator: Adam, lr = 1e-3, categorical_crossentropy
  • Batch size: roughly 8000 samples
  • Mini-batch size (weight update cycle): 32
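
Spelled out, those settings correspond to something like the following (dis_optimizer, dis_loss, GAN_optimizer and GAN_loss are the names used in the training loop; the import path assumes Keras 2.1):

from keras.optimizers import Adam

GAN_optimizer = Adam(lr=1e-4)
GAN_loss = 'categorical_crossentropy'

dis_optimizer = Adam(lr=1e-3)
dis_loss = 'categorical_crossentropy'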

Training loop:

#Pre-training Discriminator Network
#Load new batch of real frames
frames = load_data(data_dir)
frames_label = np.zeros((frames.shape[0], 1, 2))
frames_label[:, :, 0] = 1 #mark as real frames

#Generate Frames from noise vector
X_noise = noisegen((frames.shape[0], 1, n_noise))
generated_frames = generator.predict(X_noise)
generated_label = np.zeros((generated_frames.shape[0], 1, 2))
generated_label[:, :, 1] = 1 #mark as false frames

#Prep Data - concat real and false data
dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

#Make discriminator trainable and train for 8 epochs
make_trainable(discriminator, True)
discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
fit_model(discriminator, dis_batch_x, dis_batch_y, 8)

#Training Loop
for d in range(data_sets):
    print "Starting New Dataset: {0}/{1}".format(d+1, data_sets)

    """ Fit Discriminator """
    #Load new batch of real frames
    frames = load_data(data_dir)
    frames_label = np.zeros((frames.shape[0], 1, 2))
    frames_label[:, :, 0] = 1 #mark as real frames

    #Generate Frames from noise vector
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_frames = generator.predict(X_noise)
    generated_label = np.zeros((generated_frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    #Prep Data - concat real and false data
    dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
    dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

    #Make discriminator trainable & fit
    make_trainable(discriminator, True)
    discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
    fit_model(discriminator, dis_batch_x, dis_batch_y)


    """ Fit Generator """
    #Prep Data
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_label = np.zeros((generated_frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    make_trainable(discriminator, False)
    GAN.layers[2].trainable = False #done twice just to be sure
    GAN.compile(optimizer=GAN_optimizer, loss=GAN_loss) 
    fit_model(GAN, X_noise, generated_label)
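
make_trainable, fit_model and load_data are used above but not shown. For completeness, plausible minimal versions are sketched below; these are my assumptions about what those helpers do, not the actual code. In particular, load_data would need to produce the (1, 513) spectral power windows described at the top, so the use of librosa and the STFT parameters (n_fft=1024 giving 513 bins) are guesses.

import glob
import librosa
import numpy as np

def make_trainable(model, trainable):
    #Toggle trainability of every layer (takes effect on the next compile())
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable

def fit_model(model, x, y, epochs=1, batch_size=32):
    #Thin wrapper around Keras fit(); 32 matches the mini-batch size listed above
    model.fit(x, y, epochs=epochs, batch_size=batch_size, shuffle=True)

def load_data(data_dir):
    #Guess: 513-bin spectral power frames (n_fft=1024 -> 513 bins),
    #reshaped to (n_frames, 1, 513) to match the labels built in the loop
    frames = []
    for path in glob.glob(data_dir + '/*.wav'):
        y, sr = librosa.load(path, sr=None)
        power = np.abs(librosa.stft(y, n_fft=1024)) ** 2   #(513, n_frames)
        frames.append(power.T)
    return np.concatenate(frames, axis=0).reshape((-1, 1, 513))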

Finally, a bit of system info:

  • OSX 10.12
  • Tensorflow 1.5.0(GPU)
  • Keras 2.1.3
  • Python 2.7

Many thanks in advance!

1 Answer:

Answer 0: (score: 0)

The solution turned out to be that I wasn't swapping the True/False classes when training the generator (as suggested at https://github.com/soumith/ganhacks), which I think effectively turned it into gradient ascent.

Some clarification on this would be nice.
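
In other words, when fitting the generator through the stacked GAN, the noise batch should be labelled as real, so that the generator is rewarded for fooling the discriminator. Against the loop in the question, the change would look roughly like this:

#When training the generator, label its output as REAL (class 0 in the
#question's one-hot scheme) rather than false:
X_noise = noisegen((frames.shape[0], 1, n_noise))
y_swapped = np.zeros((X_noise.shape[0], 1, 2))
y_swapped[:, :, 0] = 1   #"real" label, so gradients push G towards fooling D

make_trainable(discriminator, False)
GAN.compile(optimizer=GAN_optimizer, loss=GAN_loss)
fit_model(GAN, X_noise, y_swapped)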
