Variational autoencoder - comparative analysis of two pieces of code

Date: 2018-05-03 08:22:03

Tags: keras, autoencoder

Could you comment on these two versions of the variational autoencoder loss and tell me why they give me such different results?

Dataset:

data1 = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
   1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   0, 0, 0, 0, 0, 0, 0, 0], dtype='int32')

data2 = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
   1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   0, 0, 0, 0, 0, 0, 0, 0], dtype='int32')


data3 = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
   1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   0, 0, 0, 0, 0, 0, 0, 0], dtype='int32')

There are 100 copies of each sample, so I have 300 samples in total.
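For context, the 300 training samples are assumed to be built by stacking 100 copies of each pattern into one matrix; the name x_train and the exact stacking code are assumptions, not shown in the question:

import numpy as np

# Assumed assembly: 100 copies of each pattern, stacked row-wise
x_train = np.concatenate([
    np.tile(data1, (100, 1)),
    np.tile(data2, (100, 1)),
    np.tile(data3, (100, 1)),
]).astype('float32')
np.random.shuffle(x_train)  # mix the three groups before training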

Code 1:

from keras import objectives
from keras import backend as K

def vae_loss(x, x_decoded_mean):
    # Reconstruction term: binary cross-entropy averaged over the features
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    # KL divergence to a unit Gaussian, averaged over batch and latent dimensions
    kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var))
    loss = xent_loss + kl_loss
    return loss

vae.compile(optimizer='rmsprop', loss=vae_loss)
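Both loss functions refer to tensors (x, x_decoded_mean, z_mean, z_log_var) that are defined outside the snippets. As a point of reference, here is a minimal sketch of the encoder/decoder scaffolding they assume, modelled on the standard Keras VAE example; the sizes intermediate_dim and latent_dim are assumptions, not values taken from the question:

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K

original_dim = len(data1)   # number of input features
intermediate_dim = 32       # assumed hidden size
latent_dim = 2              # assumed latent size

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

def sampling(args):
    # Reparameterisation trick: z = mu + sigma * epsilon
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.0)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = Lambda(sampling)([z_mean, z_log_var])
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
x_decoded_mean = decoder_mean(decoder_h(z))
vae = Model(x, x_decoded_mean)  # used by Code 1; Code 2 redefines vae below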

Code 2:

from keras import metrics
from keras import backend as K
from keras.layers import Layer
from keras.models import Model

def zero_loss(y_true, y_pred):
    # Dummy loss: the real loss is attached inside the custom layer via add_loss
    return K.zeros_like(y_pred)

class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean):
        # Reconstruction term summed over all the features
        xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
        # KL divergence summed over the latent dimensions
        kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean)
        self.add_loss(loss, inputs=inputs)
        # The output of this layer is never used; return a dummy tensor
        return K.ones_like(x)

loss_layer = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, [loss_layer])
vae.compile(optimizer='rmsprop', loss=[zero_loss])
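For a side-by-side comparison of the two losses: Keras's binary_crossentropy averages over the feature axis, so Code 1's reconstruction term is roughly original_dim times smaller than Code 2's, and Code 1 also averages the KL term where Code 2 sums it over the latent dimensions. A sketch of Code 1's loss rewritten with Code 2's scaling (assuming the same z_mean / z_log_var tensors as above) would look like this:

from keras import metrics
from keras import backend as K

def vae_loss_scaled(x, x_decoded_mean):
    # Reconstruction summed over all features instead of averaged
    xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    # KL divergence summed over the latent dimensions instead of averaged
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent_loss + kl_loss)

vae.compile(optimizer='rmsprop', loss=vae_loss_scaled)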

The results are very different and I cannot see why. The latent-space plots differ: Code 2 shows separation between the groups while Code 1 does not. Code 1's vae.predict(...) is not accurate; it gives me 1 for all the features. Code 2 gives me accurate reconstructions with:

sent_encoded = encoder.predict(np.array(test), batch_size = batch_size)
sent_decoded = generator.predict(sent_encoded)

whereas Code 1 is not accurate at all.
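The encoder and generator models used above are not defined in the question; they are assumed to follow the standard Keras VAE example, reusing the layer names from the sketch after Code 1:

# Assumed helper models, mirroring the Keras VAE example
encoder = Model(x, z_mean)                        # input -> latent mean

decoder_input = Input(shape=(latent_dim,))        # latent point -> reconstruction
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)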

Both experiments use the same layers. So, once again, as described above: why do the two versions behave so differently on this dataset, and which one is the best solution?

0 Answers:

No answers yet.