我正在修改Brieman的随机森林程序(我不知道C / C ++),所以我在R中从头开始编写自己的RF变体我的程序和标准程序之间的区别主要在于如何计算分割点和终端节点中的值 - 一旦我在森林中有一棵树,它就可以被认为与典型RF中的树非常相似算法
我的问题是它的预测速度非常慢,而且我在思考如何加快速度方面遇到了困难。
测试树对象已链接here,某些测试数据已链接here。您可以直接下载,或者如果安装了repmis
,则可以在下面加载它。它们被称为testtree
和sampx
。
library(repmis)
testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8")
sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")
编辑:不知何故,我还没有真正学会如何使用github。我已将所需文件上传到存储库here - 道歉,我现在无法弄清楚如何获得固定链接...
这里有一点关于对象的结构:
1> summary(testtree)
Length Class Mode
nodes 7 -none- list
minsplit 1 -none- numeric
X 29 data.frame list
y 6719 -none- numeric
weights 6719 -none- numeric
oob 2158 -none- numeric
1> summary(testtree$nodes)
Length Class Mode
[1,] 4 -none- list
[2,] 8 -none- list
[3,] 8 -none- list
[4,] 7 -none- list
[5,] 7 -none- list
[6,] 7 -none- list
[7,] 7 -none- list
1> summary(testtree$nodes[[1]])
Length Class Mode
y 6719 -none- numeric
output 1 -none- numeric
Terminal 1 -none- logical
children 2 -none- numeric
1> testtree$nodes[[1]][2:4]
$output
[1] 40.66925
$Terminal
[1] FALSE
$children
[1] 2 3
1> summary(testtree$nodes[[2]])
Length Class Mode
y 2182 -none- numeric
parent 1 -none- numeric
splitvar 1 -none- character
splitpoint 1 -none- numeric
handedness 1 -none- character
children 2 -none- numeric
output 1 -none- numeric
Terminal 1 -none- logical
1> testtree$nodes[[2]][2:8]
$parent
[1] 1
$splitvar
[1] "bizrev_allHH"
$splitpoint
25%
788.875
$handedness
[1] "Left"
$children
[1] 4 5
$output
[1] 287.0085
$Terminal
[1] FALSE
output
是该节点的返回值 - 我希望其他一切都是不言自明的。
我写的预测功能有效,但速度太慢了。基本上它&#34;走下树&#34;,通过观察观察:
predict.NT = function(tree.obj, newdata=NULL){
if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
tree = tree.obj$nodes
if (length(tree)==1){#Return the mean for a stump
return(rep(tree[[1]]$output,length(X)))
}
pred = apply(X = newdata, 1, godowntree, nn=1, tree=tree)
return(pred)
}
godowntree = function(x, tree, nn = 1){
while (tree[[nn]]$Terminal == FALSE){
fb = tree[[nn]]$children[1]
sv = tree[[fb]]$splitvar
sp = tree[[fb]]$splitpoint
if (class(sp)=='factor'){
if (as.character(x[names(x) == sv]) == sp){
nn<-fb
} else{
nn<-fb+1
}
} else {
if (as.character(x[names(x) == sv]) < sp){
nn<-fb
} else{
nn<-fb+1
}
}
}
return(tree[[nn]]$output)
}
问题是它真的很慢(当你考虑非样本树更大,我需要做很多次)时,即使是一棵简单的树:
library(microbenchmark)
microbenchmark(predict.NT(testtree,sampx))
Unit: milliseconds
expr min lq mean median uq
predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274
max neval
40.4691 100
我今天从某人那里得到了一个想法,我可以编写一个函数工厂类型的函数(即:生成闭包的函数,我只是在学习),将我的树分解成一堆嵌套的if / else语句。然后我可以通过它发送数据,这可能比一遍又一遍地从树中提取数据更快。我还没有编写函数函数生成函数,但我亲自编写了我从中得到的那种输出,并测试了它:
predictif = function(x){
if (x[names(x) == 'bizrev_allHH'] < 788.875){
if (x[names(x) == 'male_head'] <.872){
return(548)
} else {
return(165)
}
} else {
if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
return(-283)
}else{
return(-11.4)
}
}
}
predictif.NT = function(tree.obj, newdata=NULL){
if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
tree = tree.obj$nodes
if (length(tree)==1){#Return the mean for a stump
return(rep(tree[[1]]$output,length(X)))
}
pred = apply(X = newdata, 1, predictif)
return(pred)
}
microbenchmark(predictif.NT(testtree,sampx))
Unit: milliseconds
expr min lq mean median uq
predictif.CT(testtree, sampx) 12.77701 12.97551 14.21417 13.18939 13.67667
max neval
30.48373 100
快一点,但不多!
我真的很感激任何提高速度的想法!或者,如果答案是&#34;如果不将其转换为C / C ++,那么你真的无法获得更快的速度,这也是有价值的信息(特别是如果你给出的话)我有一些关于为什么会这样的信息。
虽然我当然很欣赏R中的答案,但伪代码的答案也会非常有用。
谢谢!
答案 0 :(得分:5)
加速功能的秘诀是矢量化。不是单独对每一行执行所有操作,而是一次在所有行上执行它们。
让我们重新考虑您的predictif
功能
predictif = function(x){
if (x[names(x) == 'bizrev_allHH'] < 788.875){
if (x[names(x) == 'male_head'] <.872){
return(548)
} else {
return(165)
}
} else {
if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
return(-283)
}else{
return(-11.4)
}
}
}
这是一种缓慢的方法,因为它在每个单独的实例上应用所有这些操作。函数调用,if语句,尤其是names(x) == 'bizrev_allHH'
之类的操作都会产生一些开销,当你为每个实例执行操作时,这些开销会增加。
相比之下,简单地比较两个数字非常快!因此,请编写上述的矢量化版本。
predictif_fast <- function(newdata) {
n1 <- newdata$bizrev_allHH < 788.875
n2 <- newdata$male_head < .872
n3 <- newdata$nondurable_exp_mo < 4190.965
ifelse(n1, ifelse(n2, 548.55893, 165.15537),
ifelse(n3, -283.35145, -11.40185))
}
注意,这个非常重要,这个函数是而不是传递一个实例。它意味着传递您的整个新数据。这是有效的,因为<
和ifelse
操作都是向量化的:当给定向量时,它们会返回一个向量。
让我们比较你的功能和这个新功能:
> microbenchmark(predictif.NT(testtree, sampx),
predictif_fast(sampx))
Unit: microseconds
expr min lq mean median uq
predictif.NT(testtree, sampx) 12106.419 13144.2390 14684.46 13719.406 14593.1565
predictif_fast(sampx) 189.093 213.6505 263.74 246.192 260.7895
max neval cld
79136.335 100 b
2344.059 100 a
请注意,我们通过矢量化获得了50倍的加速。
顺便提一下,有可能将速度提高很多(如果你通过索引获得聪明的话,有更快的ifelse
替代方案),但是从“在每一行上执行一个函数”到“执行操作”的整体切换整个矢量“让你获得最大的加速。
这并不能完全解决您的问题,因为您需要在常规树上执行这些矢量化操作,而不仅仅是在这个特定的树上执行。我不会为您解决一般版本,但考虑到您可以重写godowntree
函数,以便它占用整个数据框并在完整版本上执行操作,而不仅仅是一个。然后,不要使用if
分支,而是保留每个实例当前所在子项的向量。