我一直在研读 Rafael A. Irizarry 的书《Introduction to Data Science》,
书中反复出现一种决策边界图(下图中的左图),我想把它重新绘制出来。
我在 https://michael.hahsler.net/SMU/EMIS7332/R/viz_classifier.html 上找到了一段可以绘制决策边界图的代码。它能够完成工作,但画出来的图看起来并不像书中的那样。
library(randomForest)
library(tidyverse)
library(caret)
library(dslabs)
#' Plot a classifier's decision boundary over two predictors
#'
#' Draws a scatter plot of the first two columns of `data`, evaluates `model`
#' on a regular `resolution` x `resolution` grid spanning the ranges of those
#' two columns, and overlays contour lines where the predicted class changes.
#'
#' @param model A fitted model whose `predict()` method accepts
#'   `type = predict_type` and returns class labels (or a list with a
#'   `class` element).
#' @param data A data frame (or tibble) whose first two columns are the
#'   predictors to plot.
#' @param class Name of the class-label column in `data`, or `NULL` to draw
#'   all observations with the same colour and symbol.
#' @param predict_type Value forwarded as `type` to `predict()`.
#' @param resolution Number of grid points along each axis.
#' @param showgrid If `TRUE`, each grid point is drawn as a small dot coloured
#'   by its predicted class.
#' @param ... Further arguments forwarded to `plot()`.
#' @return Invisibly, a `resolution` x `resolution` integer matrix of predicted
#'   class codes; rows follow the x grid, columns follow the y grid.
decisionplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE, ...) {
  if (!is.null(class)) {
    # `[[` yields a plain vector for data frames and tibbles alike, and the
    # factor coercion gives character/numeric labels the same integer coding
    # that is used for the predicted classes on the grid.
    cl <- as.factor(data[[class]])
  } else {
    cl <- 1
  }
  data <- data[, 1:2]

  plot(data, col = as.integer(cl) + 1L, pch = as.integer(cl) + 1L, ...)

  # make grid: x varies slowest, y varies fastest
  r <- vapply(data, range, numeric(2), na.rm = TRUE)
  xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
  ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
  g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
  colnames(g) <- colnames(r)
  g <- as.data.frame(g)

  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g, type = predict_type)
  if (is.list(p)) p <- p$class
  p <- as.factor(p)

  if (showgrid) points(g, col = as.integer(p) + 1L, pch = ".")

  # byrow = TRUE puts all y values for one x into a row, so z[i, j] is the
  # prediction at (xs[i], ys[j]) as contour() expects.
  z <- matrix(as.integer(p), nrow = resolution, byrow = TRUE)

  # Boundaries lie halfway between consecutive class codes of the prediction;
  # with a single predicted class there is nothing to draw.
  n_classes <- nlevels(p)
  if (n_classes > 1L) {
    contour(xs, ys, z, add = TRUE, drawlabels = FALSE,
            lwd = 2, levels = seq_len(n_classes - 1L) + 0.5)
  }

  invisible(z)
}
# Fit a random forest predicting y from all other columns of mnist_27$train,
# then draw its decision boundary over the x_1 / x_2 predictors.
train_rf <- randomForest(y ~ ., data = mnist_27$train)
decisionplot(
  train_rf,
  data = select(mnist_27$train, x_1, x_2, y),
  class = "y"
)
我希望有人能帮我绘制出像书中那样的决策边界图。
答案 0 :(得分:1)
感谢尼尔森。参考了您的链接和其他一些资源之后,我把问题解决了。
library(randomForest)
library(tidyverse)
library(caret)
library(dslabs)
library(ggthemes)
# Decision-boundary plot for a random forest fitted on mnist_27$train:
# the fill shows the predicted probability of class `2` on a regular grid,
# and a black contour marks where the predicted class changes.
model <- randomForest(y ~ ., data = mnist_27$train)
data <- mnist_27$train %>% select(x_1, x_2, y)
class <- "y"
resolution <- 75

# Class labels of the observations (not used by the plot below); `[[` returns
# a plain vector for both data frames and tibbles.
if (!is.null(class)) cl <- data[[class]] else cl <- 1
data <- data[, 1:2]

# Regular grid over the range of the two predictors:
# x varies slowest, y varies fastest.
r <- vapply(data, range, numeric(2), na.rm = TRUE)
xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
colnames(g) <- colnames(r)
g <- as.data.frame(g)

# Predicted class and class probabilities at every grid point.
q <- predict(model, g, type = "class")
p <- predict(model, g, type = "prob")
p <- p %>%
  as.data.frame() %>%
  mutate(pred = as.integer(q))

# Keep everything the layers need in one data frame so aes() refers to
# columns of the layer data instead of reaching into other objects with `$`.
grid_pred <- g %>%
  mutate(prob_2 = p$`2`, pred = p$pred)

# No colour aesthetic is mapped (the contour colour is a constant), so no
# colour scale is added.
ggplot(grid_pred, aes(x = x_1, y = x_2)) +
  geom_raster(aes(fill = prob_2), interpolate = TRUE) +
  # pred is the integer class code (1 or 2), so 1.5 is the class boundary
  geom_contour(aes(z = pred), breaks = 1.5, color = "black", size = 1) +
  theme_few() +
  labs(fill = "") +
  scale_fill_gradient2(
    low = "#338cea", mid = "white", high = "#dd7e7e",
    midpoint = 0.5, limits = range(grid_pred$prob_2)
  ) +
  theme(legend.position = "none")