我正在通过EdEx完成哈佛R课程;我要讲的是机器学习模块,涉及knn。我使用mnist_27训练数据创建了knn拟合,然后使用预测函数来确定结果是数字2还是7。使用ggplot,我已基于网格(x_1)上的像素绘制了预测点(y)。和x_2);然后我用y上色。我现在想做的是使用stat_contour在p = 0.5边界处放置轮廓。但是,出现此错误:
在
stat_contour()
中计算失败:等高线在z
和x
的每种组合中需要单个y
。
library(tidyverse)
library(caret)
library(dslabs)
data("mnist_27")
knn_fit <- knn3(y ~ ., data = mnist_27$train, k = 5)
x_1 <- mnist_27$train$x_1
x_2 <- mnist_27$train$x_2
y_x <- predict(knn_fit, mnist_27$train, type = "class")
p_hat_knn <- predict(knn_fit, mnist_27$train, type = "prob")
p_x <- p_hat_knn[,2]
knn_df <- data.frame(x_1, x_2, p_x, y_x)
plot_val <- knn_df %>%
ggplot() +
geom_point(aes(x = x_1, y = x_2, colour = factor(y_x)), shape=21, size=2, stroke=1) +
stat_contour(aes(x = x_1, y = x_2, z=p_x), breaks=c(0.5), color="black")
plot(plot_val)
该错误告诉我,每个(x_1,x_2)对的轮廓都没有概率,但是我的数据框每行都有一个p_x,所以我不确定出什么问题了。如果有人可以提供帮助,将不胜感激。
答案 0 :(得分:0)
我不太清楚为什么,但是我认为stat_contour失败的原因是由于对(x_1,x_2)以及p的观察不足。
我没有使用火车数据集(800个观测值)中的(x_1,x_2),而是使用了完整的(mnist $ true_p),它具有22500个观测值。我重新编码为使用(mnist $ true_p $ x_1,mnist $ true_p $ x_2)从合适的位置获取p_x。使用相同的代码,stat_contour然后起作用。
k_val <- 1
knn_fit <- knn3(y ~ ., data = mnist_27$train, k = k_val)
x_1 <- mnist_27$true_p$x_1
x_2 <- mnist_27$true_p$x_2
knn_df <- data.frame(x_1, x_2)
y_x <- predict(knn_fit, knn_df, type = "class")
p_hat_knn <- predict(knn_fit, knn_df, type = "prob")
p_x <- p_hat_knn[,2]
knn_df <- data.frame(x_1, x_2, p_x)
p1 <- ggplot() +
geom_point(data=mnist_27$train, aes(x = x_1, y = x_2, colour = factor(y)), shape=21, size=2, stroke=1) +
stat_contour(data=knn_df, aes(x=x_1, y=x_2, z=p_x), breaks=c(0.5), color="black")
plot(p1)
knn_fit <- knn3(y ~ ., data = mnist_27$test, k = k_val)
x_1 <- mnist_27$true_p$x_1
x_2 <- mnist_27$true_p$x_2
knn_df <- data.frame(x_1, x_2)
y_x <- predict(knn_fit, knn_df, type = "class")
p_hat_knn <- predict(knn_fit, knn_df, type = "prob")
p_x <- p_hat_knn[,2]
knn_df <- data.frame(x_1, x_2, p_x)
p2 <- ggplot() +
geom_point(data=mnist_27$test, aes(x = x_1, y = x_2, colour = factor(y)), shape=21, size=2, stroke=1) +
stat_contour(data=knn_df, aes(x=x_1, y=x_2, z=p_x), breaks=c(0.5), color="black")
plot(p2)
grid.arrange(p1, p2, nrow=1)
代码现在产生了我所需要的。
如果有人有其他建议,请告诉我。谢谢。