我正在试图弄清楚如何在ggplot2中为拟合的svm模型绘制决策边界。现在,我正在尝试使用stat_contour这样做。这是我的代码,最后调用了我的函数。您可以在我的github页面上找到我正在使用的数据文件:
train <- read.table('train.txt', col.names = c('digit', 'intensity', 'symmetry'))
test <- read.table('test.txt', col.names = c('digit', 'intensity', 'symmetry'))
digits.SVM <- function(train, test, digits = c(1, 5), C = 0.01, kernel = 'radial', degree = 3, gamma = 1, coef0 = 0, scale = FALSE, type = 'C-classification', plotApproximation = FALSE) {
library(e1071)
library(ggplot2)
library(reshape2)
if(length(digits) != 1 && length(digits) != 2)
stop('Invalid length of digits vector. Must specify one or two digits to classify')
if(length(digits) == 2) {
train <- train[(train$digit == digits[1]) | (train$digit == digits[2]), ]
test <- test[(test$digit == digits[1]) | (test$digit == digits[2]), ]
}
train$class <- -1
test$class <- -1
train[train$digit == digits[1], ]$class <- 1
test[test$digit == digits[1], ]$class <- 1
fit <- svm(class~intensity + symmetry, data = train, cost = C, kernel = kernel, degree = degree, gamma = gamma, coef0 = coef0, scale = scale, type = type)
class_fitted <- predict(fit, train[c('intensity', 'symmetry')])
gridRange <- apply(train[c('intensity', 'symmetry')], 2, range)
x1 <- seq(from = gridRange[1, 1] - 0.025, to = gridRange[2, 1] + 0.025, length = 75)
x2 <- seq(from = gridRange[1, 2] - 0.05, to = gridRange[2, 2] + 0.05, length = 75)
grid <- expand.grid(intensity = x1, symmetry = x2)
grid$class <- predict(fit, grid)
decisionValues <- predict(fit, grid, decision.values = TRUE)
grid$z <- as.vector(attributes(decisionValues)$decision)
print(colnames(grid))
print(head(grid))
p <- ggplot(data = grid, aes(intensity, symmetry, colour = as.factor(class))) +
geom_point(size = 1.5) +
scale_fill_manual(values = c('red', 'black')) +
stat_contour(data = grid, aes(x = intensity, y = symmetry, z = z), breaks = c(0)) +
geom_point(data = train, aes(intensity, symmetry, colour = as.factor(class)), alpha = 0.7) +
scale_colour_manual(values = c('red', 'black')) + labs(colour = 'Class') +
scale_x_continuous(expand = c(0,0)) +
scale_y_continuous(expand = c(0,0))
print(p)
mean(train$class != class_fitted)
}
digits.SVM(train, test, digits = c(0), kernel = 'polynomial', degree = 2, coef0 = 1)
在stat_contour()中设置“break”选项时出现问题。我设定的大多数值都不会引起任何问题;这是break = -1时产生的图。
然而,正确的边界对应于设置中断= 0时产生的轮廓,并且当我设置中断接近0时,ggplot开始绘制轮廓图。它开始切断,正好为0,它根本不绘制轮廓。
以下是break = -0.05的图表示例:
如您所见,边界开始切断。现在这里是使用breaks = 0的图:
整个轮廓已被切除。
我也收到此错误消息:
警告讯息: 1:无法生成轮廓数据
我对ggplot2比较陌生,我不确定stat_contour()在后台做什么。我试图寻找它的实现,但没有运气。任何帮助和澄清将不胜感激!
我也欢迎任何有关更好方法的建议。
答案 0 :(得分:1)
我成功地绘制了轮廓图,但是,我使用R基本图形代替ggplot2。我仍然有兴趣学习如何使用ggplot2创建类似的图。
这是我更新的代码和一些示例图:
train <- read.table('train.txt', col.names = c('digit', 'intensity', 'symmetry'))
test <- read.table('test.txt', col.names = c('digit', 'intensity', 'symmetry'))
digits.SVM <- function(train, test, digits = c(1, 5), C = 0.01, kernel = 'radial', degree = 3, gamma = 1, coef0 = 0, scale = FALSE, type = 'C-classification', classification.plot = FALSE) {
library(e1071)
if(length(digits) != 1 && length(digits) != 2)
stop('Invalid length of digits vector. Must specify one or two digits to classify')
if(length(digits) == 2) {
train <- train[(train$digit == digits[1]) | (train$digit == digits[2]), ]
test <- test[(test$digit == digits[1]) | (test$digit == digits[2]), ]
}
train$class <- -1
test$class <- -1
train[train$digit == digits[1], ]$class <- 1
test[test$digit == digits[1], ]$class <- 1
fit <- svm(class~intensity + symmetry, data = train, cost = C, kernel = kernel, degree = degree, gamma = gamma, coef0 = coef0, scale = scale, type = type)
train$fitted <- predict(fit, train[c('intensity', 'symmetry')])
test$fitted <- predict(fit, test[c('intensity', 'symmetry')])
if(classification.plot) {
gridRange <- apply(train[c('intensity', 'symmetry')], 2, range)
x1 <- seq(from = gridRange[1, 1] - 0.025, to = gridRange[2, 1] + 0.025, length = 75)
x2 <- seq(from = gridRange[1, 2] - 0.05, to = gridRange[2, 2] + 0.05, length = 75)
grid <- expand.grid(intensity = x1, symmetry = x2)
grid$class <- predict(fit, grid)
decisionValues <- predict(fit, grid, decision.values = TRUE)
grid$z <- as.vector(attributes(decisionValues)$decision)
## TESTING PURPOSES
# print(range(grid$z))
# print(sum(train$fitted == -1))
# print(length(train$fitted))
## GGPLOT VERSION OF PLOT; CONTOUR NEEDS DEBUGGING
# library(ggplot2)
# p <- ggplot(data = grid, aes(intensity, symmetry, colour = as.factor(class))) +
# geom_point(size = 1.5) +
# scale_fill_manual(values = c('red', 'black')) +
# stat_contour(data = grid, aes(x = intensity, y = symmetry, z = z), breaks = c(0)) +
# geom_point(data = train, aes(intensity, symmetry, colour = as.factor(class)), alpha = 0.7) +
# scale_colour_manual(values = c('red', 'black')) + labs(colour = 'Class') +
# scale_x_continuous(expand = c(0,0)) +
# scale_y_continuous(expand = c(0,0))
# print(p)
par(mfrow = c(1,2))
## Note: RGB Specification seems to increase running and plotting time complexity
plot(grid[c('intensity', 'symmetry')], col = ifelse(grid$class == 1, '#0571B070', '#CA002070'), main = 'Training', pch='20', cex=.2)
points(train[c('intensity', 'symmetry')], col = ifelse(train$class == 1, '#0571B070', '#CA002070'))
contour(x1, x2, matrix(grid$z, length(x1), length(x2)), level=0, lwd = 1.5, drawlabels = FALSE, add=TRUE)
plot(grid[c('intensity', 'symmetry')], col = ifelse(grid$class == 1, '#0571B070', '#CA002070'), main = 'Test', pch='20', cex=.2)
points(test[c('intensity', 'symmetry')], col = ifelse(test$class == 1, '#0571B070', '#CA002070'))
contour(x1, x2, matrix(grid$z, length(x1), length(x2)), level=0, lwd = 1.5, drawlabels = FALSE, add=TRUE)
mtext(paste('Digit Classification Plots:', digits[1], 'vs', ifelse(length(digits) == 2, digits[2], 'All'), '\nKernel:', kernel, '\nC:', C), line = -3, outer = TRUE)
}
list(E_in = mean(train$fitted != train$class), E_out = mean(test$fitted != test$class), num_support_vectors = nrow(fit$SV))
}
digits.SVM(train, test, digits = c(0), kernel = 'polynomial', degree = 2, coef0 = 1, classification.plot = TRUE)
digits.SVM(train, test, digits = c(1, 5), kernel = 'radial', gamma = 1, C = 10^6, classification.plot = TRUE)