Keras自定义丢失功能:访问当前输入模式

时间:2017-09-28 08:34:20

标签: python tensorflow keras loss

在Keras(使用Tensorflow后端)中,我的自定义丢失功能可以使用当前输入模式吗?

当前输入模式被定义为用于产生预测的输入向量。例如,请考虑以下事项:X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, shuffle=False)。然后当前输入模式是与y_train相关联的当前X_train向量(在损失函数中称为y_true)。

在设计自定义损失函数时,我打算优化/最小化需要访问当前输入模式的值,而不仅仅是当前预测。

我看了https://github.com/fchollet/keras/blob/master/keras/losses.py

我也查看了“Cost function that isn't just y_pred, y_true?

我也熟悉以前的例子来产生定制的损失函数:

import keras.backend as K

def customLoss(y_true,y_pred):
    return K.sum(K.log(y_true) - K.log(y_pred))

据推测(y_true,y_pred)在其他地方定义。我已经看了一下源代码没有成功,我想知道我是否需要自己定义当前的输入模式,或者我的丢失函数是否已经可以访问它。

2 个答案:

答案 0 :(得分:15)

您可以将损失函数包装为内部函数并将输入张量传递给它(通常在将其他参数传递给损失函数时)。

def custom_loss_wrapper(input_tensor):
    def custom_loss(y_true, y_pred):
        return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
    return custom_loss

input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

您可以验证input_tensor和损失值(主要是K.mean(input_tensor)部分)会随着向模型传递不同的X而更改。

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y)  # => 1.1974642

X *= 1000
model.test_on_batch(X, y)  # => 511.15466

答案 1 :(得分:1)

您可以使用 add_loss 将外部层传递给您的损失,在您的情况下是输入张量。

这里有一个例子:

def CustomLoss(y_true, y_pred, input_tensor):
    return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)

X = np.random.uniform(0,1, (1000,10))
y = np.random.randint(0,2, 1000)

inp = Input(shape=(10,))
hidden = Dense(100, activation='relu')(inp)
out = Dense(1, activation='sigmoid')(hidden)
target = Input((1,))
model = Model([inp,target], out)

model.add_loss( CustomLoss( target, out, inp ) )
model.compile(loss=None, optimizer='adam')
model.fit(x=[X,y], y=None, epochs=3)

在推理模式下使用模型(从输入中移除目标)

final_model = Model(model.input[0], model.output)
final_model.predict(X)