我从rpart-manpage
运行了这个例子tree <- rpart(Species~., data = iris)
plot(tree,margin=0.1)
text(tree)
现在我想修改它,用于另一个数据集
digitstrainURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra"
digitsTestURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes"
digitstrain <- read.table(digitstrainURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
digitstest <- read.table(digitsTestURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
数据集包含手写数字的数据,“Class”包含数字0-9 但是当我绘制树时,我会得到有效的浮点数,任何想法这些数字是什么意思?我希望将0-9作为文本的叶子。
答案 0 :(得分:1)
您正在尝试拟合分类树,但您的数据是整数,而不是因素。
函数rpart
将尝试猜测使用哪种方法,并在您的情况下做出错误的猜测。因此,您的代码适合基于method="anova"
的树,而您希望使用method="class"
。
试试这个:
tree <- rpart(Class~., data = digitstrain, method="class")
plot(tree,margin=0.1)
text(tree, cex=0.7)
要测试模型的准确性,您可以使用predict
获取预测值,然后创建混淆矩阵:
confusion <- data.frame(
class=factor(digitstest$Class),
predict=predict(tree, digitstest, type="class")
)
with(confusion, table(class, predict))
predict
class 0 1 2 3 4 5 6 7 8 9
0 311 1 0 0 0 0 0 7 42 2
1 0 139 186 4 0 0 0 1 10 24
2 0 0 320 14 2 3 0 7 15 3
3 0 6 0 309 1 3 0 17 0 0
4 0 1 0 5 300 0 0 0 0 58
5 0 0 0 74 0 177 0 1 14 69
6 5 0 3 9 12 0 264 11 5 27
7 2 9 11 13 0 10 0 290 0 29
8 60 0 0 0 0 32 0 21 220 3
9 1 44 0 9 20 0 0 8 0 254
请注意,使用单个树的预测并不是很好。一种改进预测的一种非常简单的方法是使用随机森林,其中包含许多符合训练数据随机子集的树木:
library(randomForest)
fst <- randomForest(factor(Class)~., data = digitstrain, method="class")
观察森林给出了更好的预测结果:
confusion <- data.frame(
class=factor(digitstest$Class),
predict=predict(fst, digitstest, type="class")
)
with(confusion, table(class, predict))
predict
class 0 1 2 3 4 5 6 7 8 9
0 347 0 0 0 0 0 0 0 16 0
1 0 333 28 1 1 0 0 1 0 0
2 0 5 359 0 0 0 0 0 0 0
3 0 4 0 331 0 0 0 0 0 1
4 0 0 0 0 362 1 0 0 0 1
5 0 0 0 8 0 316 0 0 0 11
6 1 0 0 0 0 0 335 0 0 0
7 0 26 2 0 0 0 0 328 0 8
8 0 0 0 0 0 0 0 0 336 0
9 0 2 0 0 0 0 0 2 1 331
答案 1 :(得分:0)
这种情况正在发生,因为您的Class列是数字。将其转换为因子然后尝试...
digitstrain$Class = as.factor(digitstrain$Class)
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)
结果将是