TensorFlow中的自定义损失函数(带阶跃损失)

时间:2020-08-03 16:08:46

标签: python python-3.x tensorflow keras tensorflow2.0

我编写了一个自定义损失函数,该函数对预测的错误符号最不利,对较大的误差为方差,对于较小的差异则为绝对损失:


class CustomLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, y_true, y_pred):
        error = y_true - y_pred
        wrong_direction = y_true * y_pred < 0
        small_diff = tf.abs(error) < 0.005

        large_loss = 10 * tf.square(error)
        square_loss = tf.square(error)
        linear_loss = tf.abs(error)

        return tf.where(wrong_direction, large_loss, tf.where(small_diff, linear_loss, square_loss))

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

这很好,但是培训非常慢。我以为可能是因为我没有使用tf.cond(),所以尝试以这种方式实现它,但是它不起作用:

linear_loss = tf.cond(tf.abs(error) < 0.005, tf.abs(error), tf.square(error))

return tf.where(wrong_direction, large_loss, linear_loss)

有什么想法如何正确实施它或会影响培训时间吗?

0 个答案:

没有答案