Keras自定义损失功能。如何知道损失函数中当前调用了哪些输出样本

时间:2020-07-06 13:33:10

标签: python tensorflow keras

我正在keras中构建自定义损失函数,在此我想在调用y_true之前操纵y_predmean_absolute_error()。这种操作涉及知道当前y_true是哪个样本,因此,我可以轻松地在输出中添加另一列,该列将具有顺序索引。所以我有

  • outputs[:,0],它是所有样本的实际输出变量。

  • outputs[:,1]一个我不想预测的序列索引,而只是在我的自定义损失函数中使用

     def loss(y_true, y_pred):
    
         for i in range(len(y_true)):
             index = y_true[i, 1]
             # some manipulations involving the index
         return mean_absolute_error(new_true, new_pred)
    

如何通过模型忽略索引列输出而又不尝试对其进行预测的方式来实现上述目标。

也许可以更好地解决我的问题。我的最终目标是要知道损失函数中当前有哪些样本输出(通过索引)。

1 个答案:

答案 0 :(得分:-1)

您可以尝试以下方法:

def customLoss(y_true, y_pred):
    
    # you can do your preprocessing here

    def loss(y_true, y_pred):
        for i in range(len(y_true)):
             index = y_true[i, 1]
             # some manipulations involving the index
        return mean_absolute_error(new_true, new_pred)


    return loss(y_true, y_pred)