VAE的损失值非常高

时间:2019-06-02 12:30:22

标签: tensorflow

这是代码的完整实现,几乎可以实现4的幂次精度

input_image = 784 
latent_vector_dimension = 2
epochs = 100
batch_size = 32

#encoder
input_images = Input(shape=(784,))
encoder = Dense(256,activation='relu')(input_images)
encoder = Dense(512,activation='relu')(encoder)

mean = Dense(latent_vector_dimension,activation= 'linear')(encoder)
log_stddev = Dense(latent_vector_dimension,activation= 'linear')(encoder)

#latent_vector
def create_latent_vector(args):
    mean,log_stddev = args
    epsilon = K.random_normal(shape = (latent_vector_dimension,), mean=0,stddev = 1.0)
    return mean + K.exp(log_stddev*0.5)*epsilon

z = Lambda(create_latent_vector)([mean,log_stddev])

#decoder 
decoder_1 = Dense(128,activation='relu')(z)
decoder_3 = Dense(512,activation='relu')(decoder_1)
decoder_4 = Dense(784,activation='sigmoid')(decoder_3)

vae = Model(input_images,decoder_4)
def vae_loss(original_image,reconstructed_image):
    # E P(X|Z)
    reconstruction_loss = K.sum(K.binary_crossentropy(reconstructed_image, original_image), axis=1)
    kl_loss = 0.5*K.sum(K.exp(log_stddev) + K.square(mean) -1 - log_stddev,axis=1)
    print (reconstruction_loss.shape)
    print (kl_loss.shape)
    return K.mean(kl_loss+reconstruction_loss)
vae.compile(optimizer='adam', loss=vae_loss,metrics= ['accuracy'])

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = tf.reshape(x_train, (-1,784))
vae.fit(x_train,x_train,epochs=epochs,steps_per_epoch = batch_size)

我知道该实现背后的理论,并且已经了解了VAE所具备的能力的90%。但是它的实现让我头疼。我遵循了做同样的教程,并且也有同样的问题。如果我的直觉是正确的,我一直在弄乱vae_loss函数,这可能会导致问题,但是我确信它也能正常工作。我不太确定它可能存在的其他问题,包括编码器,解码器等的体系结构。但是我的猜测是vae_loss函数。请问有人可以在这个问题上帮助我吗?我已经挠头了三天了。

0 个答案:

没有答案