我正在尝试定义自己的自定义损失,我正在庞加莱圆盘中实现双曲距离,如果其中一个点的半径大于等于1,则不会定义双曲距离。
所以我想使用tf.where来避免nan值,并返回一个mse而不是nan。
这是我的代码:
def hyperbolic_loss(y_true, y_pred):
num = 2 * tf.norm(y_true - y_pred, axis = 1)
densx = (1 - tf.norm(y_pred, axis = 1))
dendx = (1 - tf.norm(y_true, axis = 1))
frac = num/(densx * dendx)
acos = tf.math.acosh(1 + frac)
acos = tf.diag(acos)
ret = K.mean(K.mean(acos, axis = -1))
mse = K.mean(mean_squared_error(y_true, y_pred))
return tf.where(
tf.reduce_any(tf.is_nan(acos)),
mse * 10,
ret)
我在网络外部尝试了该方法,并且它可以与任何输入一起使用(即使存在半径> = 1的点),但是当我将该函数用作网络中的损失函数时,可能会发生nan
是回到。
报告距离here,r等于1。
我认为问题出在tf.where上,但我不确定。
如何避免nan
返回?