这是场景:
library(rpart); library(dplyr); library(caret)
data <- read.csv("NetworkIntrusionValidatedata.csv") #50 example rows provided below
traindata <- createDataPartition(y,p=0.9,list = F) %>% c()
train <- data[traindata,]
test <- data[-traindata,]
trainy <- y[traindata]
testy <- y[-traindata]
train <- cbind(train,trainy)
model2 <- rpart(trainy~., data = train)
prunedmodel <- prune(model2, cp = 0.27)
test2pred <- predict(model2, newdata = test, type = "prob")
test2pred <- factor(ifelse(test2pred>0.7,"normal","anomaly"))
table(test2pred)
结果是:
test2pred
anomaly normal
2254 2254
但是“测试”数据中总共只有2254个观测值。异常和正常情况下如何有2254个值?我在相同的数据上运行了一个“ bayesglm”,它可以正常工作。
样本数据:
structure(list(duration = c(0L, 0L, 2L, 0L, 1L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 37L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 805L, 0L, 0L, 0L, 0L, 0L, 0L, 8L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L), protocol_type = structure(c(2L, 2L, 2L, 1L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 1L, 3L, 2L, 2L, 2L, 3L, 2L, 2L, 2L, 2L,
3L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L), .Label = c("icmp",
"tcp", "udp"), class = "factor"), service = structure(c(46L, 46L,
20L, 14L, 56L, 23L, 50L, 56L, 23L, 19L, 56L, 50L, 46L, 56L, 56L, 23L,
23L, 23L, 46L, 46L, 29L, 44L, 23L, 23L, 6L, 10L, 23L, 23L, 15L, 46L,
23L, 50L, 23L, 46L, 46L, 25L, 23L, 19L, 12L, 20L, 46L, 23L, 23L, 23L,
32L, 23L, 55L, 23L, 56L, 46L), .Label = c("IRC", "X11", "Z39_50",
"auth", "bgp", "courier", "csnet_ns", "ctf", "daytime", "discard",
"domain", "domain_u", "echo", "eco_i", "ecr_i", "efs", "exec",
"finger", "ftp", "ftp_data", "gopher", "hostnames", "http",
"http_443", "imap4", "iso_tsap", "klogin", "kshell", "ldap", "link",
"login", "mtp", "name", "netbios_dgm", "netbios_ns", "netbios_ssn",
"netstat", "nnsp", "nntp", "ntp_u", "other", "pm_dump", "pop_2",
"pop_3", "printer", "private", "remote_job", "rje", "shell", "smtp",
"sql_net", "ssh", "sunrpc", "supdup", "systat", "telnet", "tftp_u",
"tim_i", "time", "urp_i", "uucp", "uucp_path", "vmnet", "whois"),
class = "factor"), flag = structure(c(2L, 2L, 10L, 10L, 3L, 10L, 10L,
10L, 10L, 10L, 10L, 10L, 2L, 6L, 10L, 10L, 10L, 10L, 10L, 2L, 2L, 6L,
10L, 10L, 2L, 3L, 10L, 10L, 10L, 10L, 5L, 10L, 10L, 10L, 2L, 3L, 10L,
10L, 10L, 10L, 6L, 10L, 10L, 10L, 2L, 10L, 6L, 10L, 6L, 10L), .Label
= c("OTH", "REJ", "RSTO", "RSTOS0", "RSTR", "S0", "S1", "S2", "S3", "SF", "SH"), class = "factor"), src_bytes = c(0L, 0L, 12983L, 20L,
0L, 267L, 1022L, 129L, 327L, 26L, 0L, 616L, 0L, 0L, 773L, 350L, 213L,
246L, 45L, 0L, 0L, 0L, 196L, 277L, 0L, 0L, 294L, 300L, 520L, 54L,
76944L, 720L, 301L, 1L, 0L, 0L, 209L, 220L, 43L, 88382L, 0L, 277L,
321L, 335L, 0L, 234L, 0L, 54540L, 0L, 46L), dst_bytes = c(0L, 0L, 0L,
0L, 15L, 14515L, 387L, 174L, 467L, 157L, 0L, 330L, 0L, 0L, 364200L,
3610L, 659L, 2090L, 44L, 0L, 0L, 0L, 1823L, 1816L, 0L, 0L, 6442L,
440L, 0L, 51L, 1L, 281L, 19794L, 1L, 0L, 44L, 12894L, 688L, 71L, 0L,
0L, 4968L, 2715L, 3228L, 0L, 3236L, 0L, 8314L, 0L, 45L), land = c(0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L),
wrong_fragment = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L ), urgent = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L), hot = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 4L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 2L, 0L, 0L), num_failed_logins = c(0L, 0L, 0L, 0L,
0L, 0L, 0L, 1L, 0L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), logged_in = c(0L,
0L, 0L, 0L, 0L, 1L, 1L, 0L, 1L, 0L, 0L, 1L, 0L, 0L, 1L, 1L, 1L, 1L,
0L, 0L, 0L, 0L, 1L, 1L, 0L, 0L, 1L, 1L, 0L, 0L, 1L, 1L, 1L, 0L, 0L,
0L, 1L, 1L, 0L, 0L, 0L, 1L, 1L, 1L, 0L, 1L, 0L, 1L, 0L, 0L),
num_compromised = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 1L, 0L, 0L),
root_shell = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), su_attempted = c(0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L), num_root = c(0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), num_file_creations = c(0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 4L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L), num_shells = c(0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), num_access_files = c(0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L), num_outbound_cmds = c(0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L
), is_host_login = c(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), is_guest_login = c(0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L), count = c(229L, 136L, 1L, 1L, 1L, 4L, 1L,
1L, 33L, 1L, 1L, 1L, 111L, 120L, 1L, 8L, 24L, 16L, 505L,
204L, 118L, 1L, 17L, 17L, 116L, 273L, 22L, 7L, 511L, 511L,
12L, 1L, 15L, 40L, 483L, 2L, 11L, 1L, 113L, 15L, 144L, 13L,
29L, 49L, 266L, 8L, 281L, 4L, 1L, 68L), srv_count = c(10L,
1L, 1L, 65L, 8L, 4L, 3L, 1L, 47L, 1L, 1L, 2L, 2L, 120L, 1L,
8L, 24L, 16L, 505L, 18L, 19L, 1L, 17L, 18L, 8L, 13L, 46L,
7L, 511L, 511L, 12L, 2L, 15L, 3L, 1L, 2L, 11L, 1L, 113L,
15L, 8L, 13L, 37L, 50L, 8L, 21L, 6L, 24L, 12L, 68L), serror_rate = c(0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.05, 0, 0, 0,
0, 0, 1, 0, 0.03, 0, 0, 0, 1, 0, 1, 0), srv_serror_rate = c(0,
0, 0, 0, 0.12, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0.05, 0, 0, 0, 1, 0, 0.33, 0), rerror_rate = c(1,
1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0.92, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0), srv_rerror_rate = c(1,
1, 0, 0, 0.5, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0.67, 0), same_srv_rate = c(0.04,
0.01, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.02, 1, 1, 1, 1, 1,
1, 0.09, 0.16, 1, 1, 1, 0.07, 0.05, 1, 1, 1, 1, 1, 1, 1,
0.08, 0, 0.5, 1, 1, 1, 1, 0.06, 1, 1, 1, 0.03, 1, 0.02, 1,
1, 1), diff_srv_rate = c(0.06, 0.06, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0.07, 0, 0, 0, 0, 0, 0, 0.07, 0.05, 0, 0, 0, 0.07,
0.06, 0, 0, 0, 0, 0, 0, 0, 0.38, 1, 1, 0, 0, 0, 0, 0.06,
0, 0, 0, 0.06, 0, 0.06, 0, 0, 0), srv_diff_host_rate = c(0,
0, 0, 1, 0.75, 0, 1, 0, 0.04, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0.11, 0, 0, 0.11, 0, 0, 0, 0, 1, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0.08, 0.04, 0, 0.1, 0, 0.08, 0.92, 0
), dst_host_count = c(255L, 255L, 134L, 3L, 29L, 155L, 255L,
255L, 151L, 52L, 255L, 255L, 255L, 235L, 38L, 71L, 255L,
35L, 255L, 255L, 255L, 255L, 255L, 36L, 255L, 255L, 180L,
255L, 46L, 255L, 241L, 158L, 20L, 255L, 255L, 185L, 255L,
53L, 255L, 203L, 255L, 13L, 29L, 255L, 255L, 255L, 255L,
255L, 91L, 255L), dst_host_srv_count = c(10L, 1L, 86L, 57L,
86L, 255L, 28L, 255L, 255L, 26L, 128L, 129L, 2L, 171L, 73L,
255L, 255L, 255L, 255L, 18L, 19L, 87L, 255L, 255L, 8L, 13L,
255L, 255L, 59L, 255L, 238L, 82L, 255L, 3L, 1L, 59L, 255L,
27L, 254L, 114L, 8L, 255L, 255L, 255L, 8L, 255L, 6L, 250L,
86L, 255L), dst_host_same_srv_rate = c(0.04, 0, 0.61, 1,
0.31, 1, 0.11, 1, 1, 0.5, 0.5, 0.51, 0.01, 0.73, 0.16, 1,
1, 1, 1, 0.07, 0.07, 0.34, 1, 1, 0.03, 0.05, 1, 1, 1, 1,
0.99, 0.52, 1, 0.01, 0, 0.24, 1, 0.51, 1, 0.38, 0.03, 1,
1, 1, 0.03, 1, 0.02, 0.98, 0.34, 1), dst_host_diff_srv_rate = c(0.06,
0.06, 0.04, 0, 0.17, 0, 0.72, 0, 0, 0.08, 0.01, 0.03, 0.07,
0.07, 0.05, 0, 0, 0, 0, 0.07, 0.05, 0.01, 0, 0, 0.06, 0.06,
0, 0, 0, 0, 0.01, 0.06, 0, 0.58, 1, 0.03, 0, 0.08, 0.01,
0.01, 0.06, 0, 0, 0, 0.06, 0, 0.07, 0.01, 0.03, 0), dst_host_same_src_port_rate = c(0,
0, 0.61, 1, 0.03, 0.01, 0, 0, 0.01, 0.02, 0, 0, 0, 0, 0.03,
0.01, 0, 0.03, 1, 0, 0, 0.01, 0, 0.03, 0, 0, 0.01, 0, 1,
0.83, 0, 0.01, 0.05, 0.99, 0, 0.01, 0, 0.02, 0, 0.38, 0,
0.08, 0.03, 0, 0, 0, 0, 0, 0.01, 0.26), dst_host_srv_diff_host_rate = c(0,
0, 0.02, 0.28, 0.02, 0.03, 0, 0, 0.03, 0, 0, 0, 0, 0, 0.04,
0.04, 0, 0.05, 0, 0, 0, 0, 0, 0.02, 0, 0, 0.01, 0, 0.14,
0, 0, 0, 0.02, 0, 0, 0.03, 0, 0, 0, 0.02, 0, 0.01, 0.04,
0, 0, 0, 0, 0, 0.03, 0), dst_host_serror_rate = c(0, 0, 0,
0, 0, 0.01, 0, 0.01, 0, 0, 0, 0, 0, 0.69, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.01, 0.05, 0, 0, 0.01,
0, 0, 0, 0, 1, 0, 0.03, 0, 0, 0, 1, 0, 1, 0), dst_host_srv_serror_rate = c(0,
0, 0, 0, 0, 0, 0, 0.01, 0, 0, 0, 0, 0, 0.95, 0.77, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0), dst_host_rerror_rate = c(1,
1, 0, 0, 0.83, 0, 0.72, 0.02, 0, 0, 0.66, 0.33, 1, 0.02,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0.07, 0,
0, 0.01, 0.96, 0.89, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0.06,
0, 0), dst_host_srv_rerror_rate = c(1, 1, 0, 0, 0.71, 0,
0.04, 0.02, 0, 0, 0.32, 0, 1, 0, 0.07, 0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0.07, 0, 0, 0, 1, 0.95, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0.06, 0, 0)), row.names = c(NA, 50L), class = "data.frame")
答案 0 :(得分:0)
欢迎Stack Overflow,并感谢您阅读巡回赛。
您未包含y
的数据,因此我无法完全复制您的示例。不过,我可以解释为什么会发生这种情况以及如何解决。为此,我将使用内置的mtcars
数据集。该数据具有二进制变量。 am
指示车辆是否具有自动变速器。由于它不是一个因素,因此我将其更改为一个。然后像平常一样使用rpart
。
library(rpart)
mtcars$am = factor(mtcars$am)
model2 <- rpart(am~., data = mtcars)
test2pred <- predict(model2, newdata = mtcars, type = "prob")
然后将test2pred
更改为一个因子。我希望能够查看predict
的输出,因此在将其转换为因数时,我将使用一个新名称。
test2pred.fact
anomaly normal
32 32
请注意,mtcars数据集具有32个实例,并且像您的示例一样,每个实例产生一个“异常”和一个“正常”。让我们看看为什么。为此,我们需要返回predict
的内容。
head(test2pred)
0 1
Mazda RX4 0.1428571 0.85714286
Mazda RX4 Wag 0.1428571 0.85714286
Datsun 710 0.1428571 0.85714286
Hornet 4 Drive 0.9444444 0.05555556
Hornet Sportabout 0.9444444 0.05555556
Valiant 0.9444444 0.05555556
通过指定type = "prob"
,您的预测语句要求返回概率。是和否的概率。如果其中之一大于0.7,则另一个必须小于0.7。我认为您想做的只是看是否是的可能性大于0.7。因此,而不是
test2pred.fact <- factor(ifelse(test2pred>0.7,"normal","anomaly"))
只需用以下方法测试第一个概率:
test2pred.fact <- factor(ifelse(test2pred[,1]>0.7,"normal","anomaly"))
对于我的mtcars示例,它给出:
table(test2pred.fact)
test2pred.fact
anomaly normal
14 18
现在每个实例中的一些总计达到正确的实例数量。