我编写了一个自定义损失函数,该函数对预测的错误符号最不利,对较大的误差为方差,对于较小的差异则为绝对损失:
class CustomLoss(keras.losses.Loss):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def call(self, y_true, y_pred):
error = y_true - y_pred
wrong_direction = y_true * y_pred < 0
small_diff = tf.abs(error) < 0.005
large_loss = 10 * tf.square(error)
square_loss = tf.square(error)
linear_loss = tf.abs(error)
return tf.where(wrong_direction, large_loss, tf.where(small_diff, linear_loss, square_loss))
def get_config(self):
base_config = super().get_config()
return {**base_config}
这很好,但是培训非常慢。我以为可能是因为我没有使用tf.cond(),所以尝试以这种方式实现它,但是它不起作用:
linear_loss = tf.cond(tf.abs(error) < 0.005, tf.abs(error), tf.square(error))
return tf.where(wrong_direction, large_loss, linear_loss)
有什么想法如何正确实施它或会影响培训时间吗?