R中带有case_when的自定义Keras损失函数

时间:2018-07-21 15:36:49

标签: r tensorflow keras loss-function loss

我尝试为我的Keras模型构建自定义损失函数。我的数据集在类变量中有两个值,分别为1和2。我想优化模型以优化其正确识别1的能力。在这种情况下,是否识别2个正确与否无关紧要。

因此,当匹配变量为2时,其得分为0,而与正确匹配1类的情况相同。如果与正确的1类不匹配,则得到1分。

library(keras)
library(tidyverse)    
OneOnly <- function(y_true, y_pred) {
  K <- backend()
  K$mean(
    case_when(
      y_true == 2 ~ 0,
      y_true == 1 & y_pred == 1 ~ 0,
      TRUE ~ 1
    )
  )
}

当我尝试使用此损失函数训练模型时,出现以下错误:

RuntimeError: Evaluation error: LHS of case 1 (`y_true == 2`) must be a logical, not environment.

从错误消息中,我确实知道我无法在检查中包括变量1和2。有人对如何减轻这种情况有建议吗?

0 个答案:

没有答案