加载用户定义的损失函数+ tf.keras

时间:2019-06-19 14:16:35

标签: python tf.keras

我想重新加载tf.keras给出的训练结果,但是未存储预定义的多变量损失函数,需要将其加载。如何使用y_true,y_pred以外的参数加载此类用户定义的损失函数?

我尝试通过添加custom_objects = {'loss':loss},但仍然无法正常工作。

def loss_function(network_output, network_input):
    def loss(y_true, y_pred):
        # These are some codes doing with network_output, network_input to generate err_sqr,bound_0y,bound_1y,bound_x0,bound_x1
        # Omitted these details
        return K.mean(err_sqr) + bound_0y + bound_1y + bound_x0 + bound_x1
    return loss

# Below are the error code rasing by reload loss function
model = tf.keras.models.load_model(
    save_path,
    custom_objects=None,
    compile=True
)

我希望成功重新加载模型,但实际上它返回错误: “ ValueError:未知损失函数:损失”

0 个答案:

没有答案