决策树与R

时间:2014-06-30 10:06:18

标签: r machine-learning decision-tree

我从rpart-manpage

运行了这个例子
tree <- rpart(Species~., data = iris)
plot(tree,margin=0.1)
text(tree)

现在我想修改它,用于另一个数据集

digitstrainURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra"
digitsTestURL <- "http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes"
digitstrain <- read.table(digitstrainURL, sep=",",
                          col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))
digitstest <- read.table(digitsTestURL, sep=",",
col.names=c("i1","i2","i3","i4","i5","i6","i7","i8","i9","i10","i11","i12","i13","i14","i15","i16", "Class"))

tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)

数据集包含手写数字的数据,“Class”包含数字0-9 但是当我绘制树时,我会得到有效的浮点数,任何想法这些数字是什么意思?我希望将0-9作为文本的叶子。

2 个答案:

答案 0 :(得分:1)

您正在尝试拟合分类树,但您的数据是整数,而不是因素。

函数rpart将尝试猜测使用哪种方法,并在您的情况下做出错误的猜测。因此,您的代码适合基于method="anova"的树,而您希望使用method="class"

试试这个:

tree <- rpart(Class~., data = digitstrain, method="class")
plot(tree,margin=0.1)
text(tree, cex=0.7)

enter image description here

要测试模型的准确性,您可以使用predict获取预测值,然后创建混淆矩阵:

confusion <- data.frame(
  class=factor(digitstest$Class), 
  predict=predict(tree, digitstest, type="class")
  )
with(confusion, table(class, predict))

     predict
class   0   1   2   3   4   5   6   7   8   9
    0 311   1   0   0   0   0   0   7  42   2
    1   0 139 186   4   0   0   0   1  10  24
    2   0   0 320  14   2   3   0   7  15   3
    3   0   6   0 309   1   3   0  17   0   0
    4   0   1   0   5 300   0   0   0   0  58
    5   0   0   0  74   0 177   0   1  14  69
    6   5   0   3   9  12   0 264  11   5  27
    7   2   9  11  13   0  10   0 290   0  29
    8  60   0   0   0   0  32   0  21 220   3
    9   1  44   0   9  20   0   0   8   0 254

请注意,使用单个树的预测并不是很好。一种改进预测的一种非常简单的方法是使用随机森林,其中包含许多符合训练数据随机子集的树木:

library(randomForest)

fst <- randomForest(factor(Class)~., data = digitstrain, method="class")

观察森林给出了更好的预测结果:

confusion <- data.frame(
  class=factor(digitstest$Class), 
  predict=predict(fst, digitstest, type="class")
  )
with(confusion, table(class, predict))

     predict
class   0   1   2   3   4   5   6   7   8   9
    0 347   0   0   0   0   0   0   0  16   0
    1   0 333  28   1   1   0   0   1   0   0
    2   0   5 359   0   0   0   0   0   0   0
    3   0   4   0 331   0   0   0   0   0   1
    4   0   0   0   0 362   1   0   0   0   1
    5   0   0   0   8   0 316   0   0   0  11
    6   1   0   0   0   0   0 335   0   0   0
    7   0  26   2   0   0   0   0 328   0   8
    8   0   0   0   0   0   0   0   0 336   0
    9   0   2   0   0   0   0   0   2   1 331

答案 1 :(得分:0)

这种情况正在发生,因为您的Class列是数字。将其转换为因子然后尝试...

digitstrain$Class = as.factor(digitstrain$Class)
tree <- rpart(Class~., data = digitstrain)
plot(tree,margin=0.1)
text(tree)

结果将是

enter image description here