测试由Rpart包生成的规则

时间:2012-08-06 15:58:04

标签: r machine-learning rpart

我想以编程方式测试从树生成的一个规则。在树中,根和叶子(终端节点)之间的路径可以解释为规则。

在R中,我们可以使用rpart包并执行以下操作: (在这篇文章中,我将使用iris数据集,仅用于示例目的)

library(rpart)
model <- rpart(Species ~ ., data=iris)

通过这两行,我得到了一个名为model的树,其类为rpart.objectrpart文档,第21页)。这个对象有很多信息,并且支持多种方法。特别是,该对象具有frame变量(可以以标准方式访问:model$frame)( idem )和方法path.rpath({{ 1}}文档,第7页),它给出了从根节点到感兴趣的节点的路径(函数中的rpart参数)

node变量的row.names包含树的节点编号。 frame列为节点中的split变量var提供拟合值和yval类概率以及其他信息。

yval2

但只有> model$frame var n wt dev yval complexity ncompete nsurrogate yval2.1 yval2.2 yval2.3 yval2.4 yval2.5 yval2.6 yval2.7 1 Petal.Length 150 150 100 1 0.50 3 3 1.00000000 50.00000000 50.00000000 50.00000000 0.33333333 0.33333333 0.33333333 2 <leaf> 50 50 0 1 0.01 0 0 1.00000000 50.00000000 0.00000000 0.00000000 1.00000000 0.00000000 0.00000000 3 Petal.Width 100 100 50 2 0.44 3 3 2.00000000 0.00000000 50.00000000 50.00000000 0.00000000 0.50000000 0.50000000 6 <leaf> 54 54 5 2 0.00 0 0 2.00000000 0.00000000 49.00000000 5.00000000 0.00000000 0.90740741 0.09259259 7 <leaf> 46 46 1 3 0.01 0 0 3.00000000 0.00000000 1.00000000 45.00000000 0.00000000 0.02173913 0.97826087 列中标记为<leaf>的是终端节点( leafs )。在这种情况下,节点是2,6和7。

如上所述,您可以使用var方法提取规则(path.rpart包和文章Sharma Credit Score中使用此方法,如下所示:

此外,模型将预测值的值保存在

rattle

此值与predicted.levels <- attr(model, "ylevels") 数据集中的yval列相对应。

对于节点号为7(行号5)的叶子,预测值为

model$frame

,规则是

> ylevels[model$frame[5, ]$yval]
[1] "virginica"

因此,该规则可以理解为

> rule <- path.rpart(model, nodes = 7)

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75

我知道我可以测试(在测试数据集中,我将再次使用虹膜数据集)我对此规则有多少真正的积极因素,将新数据集分组如下

If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica

然后计算混淆矩阵

> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

(注意:我使用相同的虹膜数据集作为测试)

如何以编程方式评估规则?我可以从规则中提取条件如下

> table(hits$Species, hits$Species == "virginica")

             FALSE TRUE
  setosa         0    0
  versicolor     1    0
  virginica      0   45

但是,我怎么能从这里继续?我无法使用> unlist(rule, use.names = FALSE)[-1] [1] "Petal.Length>=2.45" "Petal.Width>=1.75" 函数

提前致谢

注意: 此问题经过大量编辑,以便更清晰

3 个答案:

答案 0 :(得分:3)

我可以通过以下方式解决这个问题

免责声明:显然必须有更好的方法来解决这个问题,但是这种黑客攻击行为并做我想做的事情......(我并不为此感到骄傲......是hackish,但是有效)

好的,让我们开始吧。基本上这个想法是使用包sqldf

如果您检查问题,最后一段代码会在树的每个路径中放入一个列表。所以,我将从那里开始

        library(sqldf)
        library(stringr)

        # Transform to a character vector
        rule.v <- unlist(rule, use.names=FALSE)[-1]
        # Remove all the dots, sqldf doesn't handles dots in names 
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
        # We have to remove all the equal signs to 'in ('
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
        # Embrace all the elements in the lists of values with " ' " 
        # The last element couldn't be modified in this way (Any ideas?) 
        rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")

        # Close the last element with apostrophe and a ")" 
        for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
          rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
        }

        # Collapse all the list in one string joined by " AND "
        rule.v <- paste(rule.v, collapse = " AND ")

        # Generate the query
        # Use any metric that you can get from the data frame
        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")

        # For debug only...
        print(query)

        # Execute and print the results
        print(sqldf(query))

就是这样!

我警告过你,这是黑客......

希望这有助于其他人...

感谢所有的帮助和建议!

答案 1 :(得分:2)

一般情况下,我不建议使用eval(parse(...)),但在这种情况下,它似乎有效:

提取规则:

rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75
rule
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

使用以下规则提取数据:

node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
head(node_data)

    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
71           5.9         3.2          4.8         1.8 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
103          7.1         3.0          5.9         2.1  virginica
104          6.3         2.9          5.6         1.8  virginica
105          6.5         3.0          5.8         2.2  virginica

答案 2 :(得分:1)

开始
Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
checking< 2.5
afford< 54
history< 3.5
coapp< 2.5

你会有一个'prob'向量,全部为零,你可以用rule16更新:

prob <- ifelse( dat[['checking']] < 2.5 &
                dat[['afford']]  < 54
                dat[['history']] < 3.5
                dat[['coapp']]  < 2.5) , 0.04, prob )

然后,您需要运行所有其他规则(不应该更改此情况的任何概率,因为树应该是不相交的估计值。)构建预测可能有比这更有效的方法。例如...... predict.rpart函数。