Error in train: unused argument in R

Time: 2018-06-05 11:13:04

Tags: r r-caret

I am trying to run an R script, but I get an error from train:

  

Error in train(frm, data = mushrooms[train_idx, ], method = "rpart", trControl = trControl, : unused arguments (data = mushrooms[train_idx, ], method = "rpart", trControl = trControl, tuneGrid = rpart.grid, metric = "Accuracy")
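For reference, R raises an "unused argument" error when the function it actually calls declares no parameter matching a named argument. A minimal base-R sketch, assuming a hypothetical one-argument function `f` that exists only for illustration, reproduces the same kind of message:

```r
# Hypothetical one-argument function, used only to illustrate the error
f <- function(x) x

# Passing a named argument that f does not declare triggers "unused argument"
msg <- tryCatch(f(1, data = 2), error = function(e) conditionMessage(e))
print(msg)
```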

What I have tried is below. The dataset is linked here: dataset to download

# Packages loaded: caret, ggplot2, dplyr, gridExtra, gmodels, ggparallel, rpart.plot, sqldf, readxl
mushrooms <- read_excel("~/Desktop/Rlang/Mushroom.xlsx")
View(mushrooms)
fields <- c("class",
            "cap_shape",
            "cap_surface",
            "cap_color",
            "bruises",
            "odor",
            "gill_attachment",
            "gill_spacing",
            "gill_size",
            "gill_color",
            "stalk_shape",
            "stalk_root",
            "stalk_surface_above_ring",
            "stalk_surface_below_ring",
            "stalk_color_above_ring",
            "stalk_color_below_ring",
            "veil_type",
            "veil_color",
            "ring_number",
            "ring_type",
            "spore_print_color",
            "population",
            "habitat")

colnames(mushrooms) <- fields

set.seed(1023)
train_idx <- createDataPartition(mushrooms$class, p=0.6, list=FALSE)
trControl <- trainControl(method = "repeatedcv",  number=10, repeats=5, verboseIter=TRUE)

frm <- paste("class ~ ", paste(relevant_features, collapse="+"))
frm

rpart.grid <- expand.grid(.cp=0)

rpart_fit <- train(frm, 
                   data = mushrooms[train_idx, ], 
                   method = "rpart", 
                   trControl = trControl,
                   tuneGrid = rpart.grid,
                   metric = "Accuracy") 

1 Answer:

Answer 0 (score: 1)

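The biggest problem in your code is frm: as @Roland said, you need to wrap it in as.formula. The code below works. I also added a step that removes near-zero-variance columns, because when train is called with a formula, a column with only one level makes it fail while setting up contrasts.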
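The difference between the pasted string and a formula object can be checked directly in base R. A minimal sketch, assuming a hypothetical three-predictor `relevant_features` vector (not the one computed from the real data), shows that paste() yields a character string, while as.formula() and stats::reformulate() both yield a formula:

```r
# Hypothetical subset of predictors, for illustration only
relevant_features <- c("odor", "gill_size", "spore_print_color")

# paste() builds a character string, not a formula object
frm_chr <- paste("class ~ ", paste(relevant_features, collapse = "+"))
print(class(frm_chr))

# as.formula() parses the string into a formula that modelling functions accept
frm <- as.formula(frm_chr)
print(class(frm))

# stats::reformulate() builds an equivalent formula straight from the term labels
frm2 <- reformulate(relevant_features, response = "class")
print(frm2)
```

reformulate() avoids string pasting altogether, which can be convenient when the predictor list is computed, as relevant_features is in the answer.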

library(readxl)
mushrooms <- read_excel("Mushroom.xlsx")

fields <- c("class",
            "cap_shape",
            "cap_surface",
            "cap_color",
            "bruises",
            "odor",
            "gill_attachment",
            "gill_spacing",
            "gill_size",
            "gill_color",
            "stalk_shape",
            "stalk_root",
            "stalk_surface_above_ring",
            "stalk_surface_below_ring",
            "stalk_color_above_ring",
            "stalk_color_below_ring",
            "veil_type",
            "veil_color",
            "ring_number",
            "ring_type",
            "spore_print_color",
            "population",
            "habitat")

colnames(mushrooms) <- fields

library(caret)
library(rpart)

y <- "class"
# nearZeroVar() returns the indices of (near-)zero-variance columns
cols_to_remove <- names(mushrooms)[nearZeroVar(mushrooms)]

#[1] "gill_attachment" "veil_type"       "veil_color"  

relevant_features <- setdiff(names(mushrooms), c(y, cols_to_remove))

mushrooms$class <- as.factor(mushrooms$class)

set.seed(1023)
train_idx <- createDataPartition(mushrooms$class, p=0.6, list=FALSE)
trControl <- trainControl(method = "repeatedcv",  number=10, repeats=5, verboseIter=TRUE)

# as.formula() converts the pasted string into a formula object
frm <- as.formula(paste("class ~ ", paste(relevant_features, collapse="+")))

rpart.grid <- expand.grid(.cp=0)

rpart_fit <- train(frm, 
                   data = mushrooms[train_idx, ], 
                   method ="rpart", 
                   trControl = trControl,
                   tuneGrid=rpart.grid,
                   metric = "Accuracy")
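The contrasts problem mentioned in the answer can be reproduced without caret. A minimal sketch, assuming a hypothetical toy data frame `toy` (not the real Mushroom.xlsx data) whose `veil_type` column has a single level: building the design matrix fails, and dropping single-valued columns with a base-R filter lets it succeed. Note that this filter only catches strictly zero variance, whereas caret's nearZeroVar() also flags near-zero-variance columns such as gill_attachment.

```r
# Hypothetical toy data: 'veil_type' has only one level
toy <- data.frame(
  class     = factor(c("e", "p", "e", "p")),
  veil_type = factor(c("p", "p", "p", "p")),
  odor      = factor(c("n", "f", "n", "f"))
)

# Building the design matrix fails: a one-level factor cannot get contrasts
err <- tryCatch(model.matrix(class ~ veil_type + odor, data = toy),
                error = function(e) conditionMessage(e))
print(err)

# Base-R filter: find columns with fewer than two distinct values
single_valued <- names(toy)[vapply(toy, function(col) length(unique(col)) < 2,
                                   logical(1))]

# Keep the remaining predictors and rebuild the formula without them
keep <- setdiff(names(toy), c("class", single_valued))
mm <- model.matrix(reformulate(keep, response = "class"), data = toy)
print(dim(mm))
```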