需要使用if语句的自定义损失函数

时间:2019-01-11 19:51:33

标签: tensorflow machine-learning keras loss-function

我正在尝试训练DNN,该DNN输出3个值(x,y,z),其中xy是我要寻找的对象的坐标,而z是概率该对象存在

我需要自定义损失功能:

如果z_true<0.5x值无关,那么误差应该等于y

否则错误应类似于(0, 0, sqr(z_true - z_pred))

我正在努力将张量和if语句混合在一起。

2 个答案:

答案 0 :(得分:2)

也许这个自定义损失函数的示例将使您启动并运行。它显示了如何将张量与if语句混合使用。

 def conditional_loss_function(l):
        def loss(y_true, y_pred):
            if l == 0: 
                return loss_funtion1(y_true, y_pred)
            else: 
                return loss_funtion2(y_true, y_pred)
        return loss

 model.compile(loss=conditional_loss_function(l), optimizer=...)

答案 1 :(得分:1)

使用Keras后端的switchhttps://keras.io/backend/#switch 它类似于tf.cond 如何在Keras中创建自定义损失:Make a custom loss function in keras