如何强制自动编码器中的瓶颈产生二进制值?

时间:2020-03-20 10:01:23

标签: tensorflow keras autoencoder loss-function

我尝试强制自动编码器中的瓶颈层产生二进制值。我通过在自定义损失函数中使用tensorflow.cond来做到这一点,对所有非0或1的值进行惩罚。但是这种方法非常慢。有更好的方法来执行此操作吗?

def custom_loss(weight):
    def loss(y_true, y_pred):
        reconstruction_loss = binary_crossentropy(y_true, y_pred)

        def binarize_loss(value):
            return tf.cond(tf.reduce_mean(value) > 0.5, lambda: tf.abs(value - 1), lambda: tf.abs(value))

        binarized_loss_value = tf.map_fn(binarize_loss, neckLayer.output)
        return reconstruction_loss + (K.mean(binarized_loss_value , axis=-1) * weight)

    return loss

1 个答案:

答案 0 :(得分:2)

我可能会摆脱tf.cond语句,因为您可以使用简单的算术来做您想做的事情:

def loss(y_true, y_pred):
        reconstruction_loss = binary_crossentropy(y_true, y_pred)

        binary_neck_loss = tf.abs(0.5 - tf.abs(0.5 - neckLayer.output))

        return reconstruction_loss + (K.mean(binary_neck_loss , axis=-1) * weight)

当然,我不知道您的数据的确切形状,但是您应该可以从那里推断出来。

相关问题