我正在尝试在Keras中实现以下形式的自定义损失功能:
这是我用来实现它的代码:
def custom_loss_function(y_true, y_pred):
a = y_pred[..., 0]
b = y_pred[..., 1]
def f(y, x):
return ((y_true-x)**2) * ( (b/(a)**b) * (x**(b-1)) * K.exp(-(x/a)**b) )
x = K.constant([ 0., 5000. ])
return K.mean(tf.contrib.integrate.odeint_fixed( f, 0., x, method = "rk4" ))
我试图在损失函数之外使用tf.contrib.integrate.odeint_fixed,并且实际上它可以工作。但是,一旦在损失函数内使用,它就会停止工作。
在此先感谢您提出任何建议。