R:班级数不等于2

时间:2015-07-22 08:39:30

标签: r

我在尝试为随机森林模型生成提升图时遇到以下错误:

预测错误(crs $ pr,no.miss):   类的数量不等于2。 ROCR目前仅支持二进制分类任务的评估。

我的代码如下:     #评估模型性能。

# Lift Chart: requires the ROCR package.

library(ROCR)

# Obtain predictions for the rf model on TBA_jul21c.csv [**train**].

crs$pr <- predict(crs$rf, newdata=na.omit(crs$dataset[crs$sample,    c(crs$input, crs$target)]), type="prob")[,2]

# Remove observations with missing target.

no.miss   <- na.omit(na.omit(crs$dataset[crs$sample, c(crs$input,   crs$target)])$targetvar)
miss.list <- attr(no.miss, "na.action")
attributes(no.miss) <- NULL

if (length(miss.list))
{
 pred <- prediction(crs$pr[-miss.list], no.miss)
} else
{
  pred <- prediction(crs$pr, no.miss)
}

# Convert rate of positive predictions to percentage.

per <- performance(pred, "lift", "rpp")
per@x.values[[1]] <- per@x.values[[1]]*100

# Plot the lift chart.
plot(per, col="#CC0000FF", lty=1, xlab="Caseload (%)", add=FALSE)

我的数据集有127个变量,包括带有Y / N标签的目标变量。我正在阅读数据集如下:

crs$dataset <- read.csv("file:///C:/MastersC/TBA_jul21c.csv",    na.strings=c(".", "NA", "", "?"), strip.white=TRUE, encoding="UTF-8")

#============================================================
# Rattle timestamp: 2015-07-22 08:36:53 x86_64-w64-mingw32 

# Note the user selections. 

# Build the training/validate/test datasets.

set.seed(crv$seed) 
crs$nobs <- nrow(crs$dataset) # 72824 observations 
crs$sample <- crs$train <- sample(nrow(crs$dataset), 0.7*crs$nobs) # 50976   observations
crs$validate <- sample(setdiff(seq_len(nrow(crs$dataset)), crs$train),   0.15*crs$nobs) # 10923 observations
crs$test <- setdiff(setdiff(seq_len(nrow(crs$dataset)), crs$train),     crs$validate) # 10925 observations

我认为这与目标变量有关,但我不确定。任何帮助将不胜感激,因为我已经坚持了几个小时。通过这是我的第一个stackoverflow帖子,所以如果我没有正确发布,请随意惩罚我。 非常感谢

1 个答案:

答案 0 :(得分:0)

使用预测函数的2个输入创建Predobject:

  1. 预测的列表(向量)(模型输出,赔率或 每个成员的概率),
  2. 标签的列表(矢量):0或1
  3. 如下所示:

    library(ROCR)
    data(ROCR.simple)
    pred <- prediction( ROCR.simple$predictions, ROCR.simple$labels)
    

    在问题的情况下,第二个参数(上述问题中的no.miss)很可能不是二进制信息列表(0,1)。

    如果您收到此错误,请检查您的第二个参数。我有同样的错误并提供正确的信息解决了它。