我正在尝试训练DNN,该DNN输出3个值(x,y,z)
,其中x
和y
是我要寻找的对象的坐标,而z
是概率该对象存在
我需要自定义损失功能:
如果z_true<0.5
和x
值无关,那么误差应该等于y
否则错误应类似于(0, 0, sqr(z_true - z_pred))
我正在努力将张量和if语句混合在一起。
答案 0 :(得分:2)
也许这个自定义损失函数的示例将使您启动并运行。它显示了如何将张量与if语句混合使用。
def conditional_loss_function(l):
def loss(y_true, y_pred):
if l == 0:
return loss_funtion1(y_true, y_pred)
else:
return loss_funtion2(y_true, y_pred)
return loss
model.compile(loss=conditional_loss_function(l), optimizer=...)
答案 1 :(得分:1)
使用Keras后端的switch
:https://keras.io/backend/#switch
它类似于tf.cond
如何在Keras中创建自定义损失:Make a custom loss function in keras