glmnet的整洁预测和混淆矩阵

时间:2018-10-12 18:02:10

标签: r r-caret glmnet

考虑以下示例:

library(quanteda)
library(caret)
library(glmnet)
library(dplyr)

dtrain <- data_frame(text = c("Chinese Beijing Chinese",
                              "Chinese Chinese Shanghai",
                              "Chinese Macao",
                              "Tokyo Japan Chinese"),
                     doc_id = 1:4,
                     class = c("Y", "Y", "Y", "N"))

# now we make the dataframe bigger 
dtrain <- purrr::map_df(seq_len(100), function(x) dtrain)

让我们创建一个稀疏的document-term-matrix并运行一些glmnet

> dtrain <- dtrain %>% mutate(class = as.factor(class))
> mycorpus <- corpus(dtrain,  text_field = 'text')
> trainingdf <- dfm(mycorpus)
> trainingdf
Document-feature matrix of: 400 documents, 6 features (62.5% sparse).

现在我们终于转向套索模型

mymodel <- cv.glmnet(x = trainingdf, y =dtrain$class, 
                     type.measure ='class',
                     nfolds = 3,
                     alpha = 1,
                     parallel = FALSE,
                     family = 'binomial') 

我有两个简单的问题。

如何将预测添加到原始dtrain数据中?确实,

的输出
mypred <- predict.cv.glmnet(mymodel, newx = trainingdf, 
                         s = 'lambda.min', type = 'class')

看起来很糟糕:

> mypred
    1  
1   "Y"
2   "Y"
3   "Y"

如何在此设置中使用caret::confusionMatrix?仅使用以下内容会产生错误:

confusion <- caret::confusionMatrix(data =mypred, 
+                                     reference = dtrain$class)
Error: `data` and `reference` should be factors with the same levels.

谢谢!

1 个答案:

答案 0 :(得分:2)

在每个分类模型中,目标变量的类都必须为factor

例如:

my_data是训练模型的数据集,而my_target是预测变量。

请注意,as.factor(my_data$my_target)会自动为您找到正确的levels

通过这种方式,我的意思是您不需要手动指定levels,但是R会为您完成。

在此处查看我们致电target时的区别:

target <- c("y", "n", "y", "n")
target
#[1] "y" "n" "y" "n" # this is a simple char
as.factor(target)
# [1] y n y n
# Levels: n y # this is a correct format, a factor with levels

这很重要,因为即使您的预测(或测试数据)仅显示target中的两个类别之一,模型也会知道实际的levels可以更多。

您当然可以设置它们:

my_pred <- factor(mypred, levels = c("Y", "N"))

要在数据中添加它们,可以使用

my_data$newpred <- my_pred

library(dplyr)
my_data %>% mutate(newpred = my_pred)