如何用R中的网格绘制非线性决策边界?

时间:2014-06-05 05:58:46

标签: r plot

我在ISLRFigure 2.13)或ESL中找到了这个特定的图表。我无法猜测作者是如何在R中做到这一点的。我知道如何轻松获得橙色和蓝色点。主要的混淆是背景点和紫色线。

任何想法?Fig 2.13, ISLR

以下是一些示例代码,用于获取带有灰色网格的黄色和橙色点。如何获得紫色的任意非线性曲线,然后根据曲线为网格着色?

set.seed(pi)
points = replicate(100, runif(2))
pointsColored = ifelse(apply(points, 2, sum) <= 1, "orange", "blue")
# Confound some
pointsColored[sample.int(length(pointsColored), 10)] = "orange"
plot(x=points[1, ], y=points[2, ])
grid(nx=100, ny=100)
# Plot points over the grid.
points(x=points[1, ], y=points[2, ], col=pointsColored)

2 个答案:

答案 0 :(得分:6)

正如我在评论中指出的那样,@ chl here在stats.stackexchange.com上提供了一个解决方案。在这里,它适用于您的数据集。

library(class)
set.seed(pi)
X <- t(replicate(1000, runif(2)))
g <- ifelse(apply(X, 1, sum) <= 1, 0, 1)
xnew <- cbind(rep(seq(0, 1, length.out=50), 50),
              rep(seq(0, 1, length.out=50), each=50))
m <- knn(X, xnew, g, k=15, prob=TRUE)
prob <- attr(m, "prob")
prob <- ifelse(m=="1", prob, 1-prob)
prob15 <- matrix(prob, 50)
par(mar=rep(3, 4))
contour(unique(xnew[, 1]), unique(xnew[, 2]), prob15, levels=0.5, 
        labels="", xlab='', ylab='', axes=FALSE, lwd=2.5, asp=1)
title(xlab=expression(italic('X')[1]), ylab=expression(italic('X')[2]), 
      line=1, family='serif', cex.lab=1.5)
points(X, bg=ifelse(g==1, "#CA002070", "#0571B070"), pch=21)
gd <- expand.grid(x=unique(xnew[, 1]), y=unique(xnew[, 2]))
points(gd, pch=20, cex=0.4, col=ifelse(prob15 > 0.5, "#CA0020", "#0571B0"))
box()

decision boundary

(更新:我更改了调色板,因为蓝色/黄色/紫色的东西很可怕。)

答案 1 :(得分:1)

这是我近似的愚蠢尝试。显然,@ StetinKolassa提出的问题是有效的,而不是通过这种近似来处理。

myCurve1 = function (x)
  abs(x[[1]] * sin(x[[1]]) + x[[2]] * sin(x[[2]]))
myCurve2 = function (x)
  abs(x[[1]] * cos(x[[1]]) + x[[2]] * cos(x[[2]]))
myCurve3 = function (x)
  abs(x[[1]] * tan(x[[1]]) + x[[2]] * tan(x[[2]]))

tmp = function (myCurve, seed=99) {
  set.seed(seed)
  points = replicate(100, runif(2))
  colors = ifelse(apply(points, 2, myCurve) > 0.5, "orange", "blue")
  # Confound some
  swapInts = sample.int(length(colors), 6)
  for (i in swapInts) {
    if (colors[[i]] == "orange") {
      colors[[i]] = "blue"
    } else {
      colors[[i]] = "orange"
    }
  }
  gridPoints = seq(0, 1, 0.005)
  gridPoints = as.matrix(expand.grid(gridPoints, gridPoints))
  gridColors = vector("character", nrow(gridPoints))
  gridPch = vector("character", nrow(gridPoints))
  for (i in 1:nrow(gridPoints)) {
    val = myCurve(gridPoints[i, ])
    if (val > 0.505) {
      gridColors[[i]] = "orange"
      gridPch[[i]] = "."
    } else if (val < 0.495) {
      gridColors[[i]] = "blue"
      gridPch[[i]] = "."
    } else {
      gridColors[[i]] = "purple"
      gridPch[[i]] = "*"
    }
  }
  plot(x=gridPoints[ , 1], y=gridPoints[ , 2], col=gridColors, pch=gridPch)
  points(x=points[1, ], y=points[2, ], col=colors, lwd=2)
}

par(mfrow=c(1, 3))
tmp(myCurve1)
tmp(myCurve2)
tmp(myCurve3)

enter image description here

相关问题