我尝试强制自动编码器中的瓶颈层产生二进制值。我通过在自定义损失函数中使用tensorflow.cond
来做到这一点,对所有非0或1的值进行惩罚。但是这种方法非常慢。有更好的方法来执行此操作吗?
def custom_loss(weight):
def loss(y_true, y_pred):
reconstruction_loss = binary_crossentropy(y_true, y_pred)
def binarize_loss(value):
return tf.cond(tf.reduce_mean(value) > 0.5, lambda: tf.abs(value - 1), lambda: tf.abs(value))
binarized_loss_value = tf.map_fn(binarize_loss, neckLayer.output)
return reconstruction_loss + (K.mean(binarized_loss_value , axis=-1) * weight)
return loss
答案 0 :(得分:2)
我可能会摆脱tf.cond
语句,因为您可以使用简单的算术来做您想做的事情:
def loss(y_true, y_pred):
reconstruction_loss = binary_crossentropy(y_true, y_pred)
binary_neck_loss = tf.abs(0.5 - tf.abs(0.5 - neckLayer.output))
return reconstruction_loss + (K.mean(binary_neck_loss , axis=-1) * weight)
当然,我不知道您的数据的确切形状,但是您应该可以从那里推断出来。