tf。海关遗失的地方

时间:2019-04-11 10:24:41

标签: python tensorflow keras loss-function

我正在尝试定义自己的自定义损失,我正在庞加莱圆盘中实现双曲距离,如果其中一个点的半径大于等于1,则不会定义双曲距离。

所以我想使用tf.where来避免nan值,并返回一个mse而不是nan。

这是我的代码:

def hyperbolic_loss(y_true, y_pred):    
    num = 2 * tf.norm(y_true - y_pred, axis = 1)
    densx = (1 - tf.norm(y_pred, axis = 1))
    dendx = (1 - tf.norm(y_true, axis = 1))
    frac = num/(densx * dendx)
    acos = tf.math.acosh(1  + frac)
    acos = tf.diag(acos)
    ret = K.mean(K.mean(acos, axis = -1))
    mse = K.mean(mean_squared_error(y_true, y_pred))

    return tf.where(
        tf.reduce_any(tf.is_nan(acos)),
        mse * 10, 
        ret)

我在网络外部尝试了该方法,并且它可以与任何输入一起使用(即使存在半径> = 1的点),但是当我将该函数用作网络中的损失函数时,可能会发生nan是回到。

报告距离here,r等于1。

我认为问题出在tf.where上,但我不确定。

如何避免nan返回?

0 个答案:

没有答案