自定义损耗/目标函数,在Keras中具有附加变量输入

时间:2018-03-18 06:59:37

标签: tensorflow keras loss-function objective-function

我正在尝试在Keras(tensorflow后端)中创建一个自定义目标函数,其中一个附加参数的值取决于正在训练的批处理。

例如:

def myLoss(self, stateValues):
    def sparse_loss(y_true, y_pred):
        foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        return tf.reduce_mean(foo * stateValues)
    return sparse_loss


self.model.compile(loss=self.myLoss(stateValue = self.stateValue),
        optimizer=Adam(lr=self.alpha))

我的列车功能如下

for batch in batches:
    self.stateValue = computeStateValueVectorForCurrentBatch(batch)
    model.fit(xVals, yVals, batch_size=<num>)

但是,丢失函数中的stateValue没有更新。它只是在model.compile步骤中使用stateValue的值。

我想这可以通过将placeHolder用于stateValue来解决,但我无法弄清楚如何做到这一点。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

您的损失函数没有得到更新,因为keras在每个批次之后都没有编译模型,因此没有使用更新的损失函数。

您可以定义一个自定义回调,它会在每个批次后更新损失值。像这样:

from keras.callbacks import Callback

class UpdateLoss(Callback):
    def on_batch_end(self, batch, logs={}):
        # I am not sure what is the type of the argument you are passing for computing stateValue ??
        stateValue = computeStateValueVectorForCurrentBatch(batch)
        self.model.loss = myLoss(stateValue)