使用决策规则拆分其他数据

时间:2018-08-30 15:10:05

标签: r weka decision-tree rweka

我正在寻找一种优雅的解决方案,使用在一个数据集中创建的决策规则(例如您的训练集),根据这些规则拆分另一个数据集的数据(例如测试数据)。

看这个例子:

# Load PimaIndiansDiabetes dataset from mlbench package
library("mlbench")
data("PimaIndiansDiabetes")
## Split in training and test (2/3 - 1/3)
idtrain <- c(sample(1:768,512))
PimaTrain <-PimaIndiansDiabetes[idtrain,]
Pimatest <-PimaIndiansDiabetes[-idtrain,]

m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                 data = PimaTrain[,-c(9)],
                 control = RWeka::Weka_control(M = 10, C= 0.25))

哪个给出以下输出:

> m1
J48 pruned tree
------------------

glucose <= 154
|   age <= 28
|   |   glucose <= 118: neg (157.0/11.0)
|   |   glucose > 118
|   |   |   pressure <= 52: pos (10.0/3.0)
|   |   |   pressure > 52: neg (54.0/12.0)
|   age > 28
|   |   glucose <= 103: neg (54.0/10.0)
|   |   glucose > 103
|   |   |   mass <= 41.3: neg (129.0/55.0)
|   |   |   mass > 41.3: pos (12.0/1.0)
glucose > 154: pos (96.0/19.0)

Number of Leaves  :     7

Size of the tree :  13

基于这些规则,您将有7组(或叶子)。我正在寻找的是在测试数据 Pimatest 上应用这些规则(而不是重新训练决策树),以便实际上每个数据点都可以指定给新的7个组之一。变量 group

输出看起来像这样:

head(Pimatest)
   pregnant glucose pressure triceps insulin mass pedigree age diabetes group
3         8     183       64       0       0 23.3    0.672  32      pos     7
4         1      89       66      23      94 28.1    0.167  21      neg     1
6         5     116       74       0       0 25.6    0.201  30      neg     5
7         3      78       50      32      88 31.0    0.248  26      pos     1
8        10     115        0       0       0 35.3    0.134  29      neg     5
11        4     110       92       0       0 37.6    0.191  30      neg     5

我目前有一个有效的解决方案,该解决方案的编码真的很糟糕,所以这就是为什么我正在寻找一个解决该问题的优雅方案。

1 个答案:

答案 0 :(得分:2)

据我了解,您希望能够将每个点绑定到对该点进行分类的规则集。您可以通过将J48树转换为party树并使用partykit软件包中的工具来到达那里。

由于您没有为随机数生成器设置种子, 我们无法获得与您完全相同的测试/培训成绩。 我将设置种子以使我的示例可重复使用,但是即使 尽管我使用您的代码,但是我的树将与您的树稍有不同。

可复制的示例(主要是您的代码)

library(RWeka)
library("mlbench")
data("PimaIndiansDiabetes")

## Split in training and test (2/3 - 1/3)
set.seed(1234)
idtrain <- c(sample(1:768,512))
PimaTrain <-PimaIndiansDiabetes[idtrain,]
Pimatest <-PimaIndiansDiabetes[-idtrain,]

m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                 data = PimaTrain[,-c(9)],
                 control = RWeka::Weka_control(M = 10, C= 0.25))
m1
J48 pruned tree
------------------
glucose <= 122
|   mass <= 26.8: neg (85.0/1.0)
|   mass > 26.8
|   |   pregnant <= 4: neg (137.0/19.0)
|   |   pregnant > 4
|   |   |   glucose <= 106: neg (44.0/10.0)
|   |   |   glucose > 106: pos (24.0/6.0)
glucose > 122
|   glucose <= 157
|   |   age <= 31
|   |   |   age <= 24: neg (30.0/5.0)
|   |   |   age > 24
|   |   |   |   pressure <= 72: pos (16.0/5.0)
|   |   |   |   pressure > 72: neg (22.0/5.0)
|   |   age > 31: pos (78.0/27.0)
|   glucose > 157: pos (76.0/13.0)

Number of Leaves  :     9
Size of the tree :      17

我的树有9片叶子,而不是7片。这是由于不同 为训练集选择的实例。现在我们准备获取规则。

library(partykit)
Pm1 = as.party(m1)
Pm1
Fitted party:
[1] root
|   [2] glucose <= 122
|   |   [3] mass <= 26.8: neg (n = 85, err = 1.2%)
|   |   [4] mass > 26.8
|   |   |   [5] pregnant <= 4: neg (n = 137, err = 13.9%)
|   |   |   [6] pregnant > 4
|   |   |   |   [7] glucose <= 106: neg (n = 44, err = 22.7%)
|   |   |   |   [8] glucose > 106: pos (n = 24, err = 25.0%)
|   [9] glucose > 122
|   |   [10] glucose <= 157
|   |   |   [11] age <= 31
|   |   |   |   [12] age <= 24: neg (n = 30, err = 16.7%)
|   |   |   |   [13] age > 24
|   |   |   |   |   [14] pressure <= 72: pos (n = 16, err = 31.2%)
|   |   |   |   |   [15] pressure > 72: neg (n = 22, err = 22.7%)
|   |   |   [16] age > 31: pos (n = 78, err = 34.6%)
|   |   [17] glucose > 157: pos (n = 76, err = 17.1%)

Number of inner nodes:    8
Number of terminal nodes: 9

这是与以前相同的树,但是具有标记节点的优点。我们还可以为每个叶子写出规则。

Pm1_rules = partykit:::.list.rules.party(Pm1)
Pm1_rules
                                                                       3 
                                         "glucose <= 122 & mass <= 26.8" 
                                                                       5 
                          "glucose <= 122 & mass > 26.8 & pregnant <= 4" 
                                                                       7 
          "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose <= 106" 
                                                                       8 
           "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose > 106" 
                                                                      12 
                "glucose > 122 & glucose <= 157 & age <= 31 & age <= 24" 
                                                                      14 
"glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure <= 72" 
                                                                      15 
 "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure > 72" 
                                                                      16 
                             "glucose > 122 & glucose <= 157 & age > 31" 
                                                                      17 
                                         "glucose > 122 & glucose > 157" 

决策被写成规则。规则集的名称是 叶节点的数量。要获取用于测试点的规则,您只需要知道它最终位于哪个叶节点即可。但是用于聚会对象的predict方法将为您提供。

TestPred = predict(Pm1, newdata=Pimatest, type="node")
TestPred
  3   4   5   6   9  12  17  20  22  27  28  29  31  32  33  35  36  38  41  43 
 17   5  16   3  17  17   5   5   7  16   3  16   8  17   3   8   3   7  17   3 
 46  48  50  56  57  60  62  64  65  66  68  70  72  75  76  79  84  95  96  97 
 17   5   3   3  17   5  16  12   8   7   5  15  14   5   3  14   3  12  16   5 
...

我截断了输出,因为它太长了。现在,例如,
我们看到第一个测试点到达节点17。我们只需要使用它来索引规则集。但是需要一点注意。 predict返回的17是一个数字。规则集的名称是一个字符串,因此我们需要使用as.character进行转换。

Pm1_rules[as.character(TestPred[1])]
                             17 
"glucose > 122 & glucose > 157" 

我们确认:

Pimatest[1,]
  pregnant glucose pressure triceps insulin mass pedigree age diabetes
3        8     183       64       0       0 23.3    0.672  32      pos

是的,glucose > 122glucose > 157

您可以通过相同的方式获取其他测试点的规则。