基于输入数据的Keras中的自定义损失函数

时间:2019-03-31 21:41:42

标签: keras

我正在尝试使用Keras创建自定义损失函数。我想根据输入来计算损失函数并预测神经网络的输出。

我尝试在Keras中使用customloss函数。我认为y_true是我们为训练提供的输出,而y_pred是神经网络的预测输出。以下损失函数与Keras中的“ mean_squared_error”损失相同。

def customloss(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

除了mean_squared_error损失,我还想使用神经网络的输入来计算自定义损失函数。有没有一种方法可以将输入作为自定义函数的参数发送到神经网络。

谢谢。

2 个答案:

答案 0 :(得分:1)

您可以使用另一个将输入张量作为参数的函数来包装您的自定义损失:

def customloss(x):
    def loss(y_true, y_pred):
        # Use x here as you wish
        err = K.mean(K.square(y_pred - y_true), axis=-1)
        return err

    return loss

然后按如下所示编译模型:

model.compile('sgd', customloss(x))

其中x是您的输入张量。

注意:未经测试。

答案 1 :(得分:1)

对于您提出的问题,我遇到了两种解决方案。

  1. 您可以将输入张量作为参数传递给自定义损失包装函数。
    def custom_loss(i):

        def loss(y_true, y_pred):
            return K.mean(K.square(y_pred - y_true), axis=-1) + something with i...
        return loss

    def baseline_model():
        # create model
        i = Input(shape=(5,))
        x = Dense(5, kernel_initializer='glorot_uniform', activation='linear')(i)
        o = Dense(1, kernel_initializer='normal', activation='linear')(x)
        model = Model(i, o)
        model.compile(loss=custom_loss(i), optimizer=Adam(lr=0.0005))
        return model

the accepted answer here

中也提到了此解决方案
  1. 您可以在输入中使用额外的数据列填充标签,并编写自定义损失。如果您只想从输入中选择一个或几个功能列,这将很有帮助。
    def custom_loss(data, y_pred):

        y_true = data[:, 0]
        i = data[:, 1]
        return K.mean(K.square(y_pred - y_true), axis=-1) + something with i...


    def baseline_model():
        # create model
        i = Input(shape=(5,))
        x = Dense(5, kernel_initializer='glorot_uniform', activation='linear')(i)
        o = Dense(1, kernel_initializer='normal', activation='linear')(x)
        model = Model(i, o)
        model.compile(loss=custom_loss, optimizer=Adam(lr=0.0005))
        return model


    model.fit(X, np.append(Y_true, X[:, 0], axis =1), batch_size = batch_size, epochs=90, shuffle=True, verbose=1)

此解决方案也可以在此thread中找到。

当我不得不在损失中使用输入要素列时,我只使用了第二种方法。我将第一种方法与标量参数一起使用;但我相信张量输入也可以。