Is there a way to fit a random forest in R and convert it to SAS code without having to type out all the if/then/else logic that getTree gives?
I have 30 trees, and the getTree output runs to about 1,900 rows.
Answer (score: 0)
Here's something I had lying around that should help you get started. So far only regression is supported, but classification should be achievable with some additional work:
/* R code for exporting the randomForest object */
#Output dataset to csv for validation in SAS
write.csv(iris,file="C:/temp/iris.csv",row.names=FALSE)
#Train a 2-tree random forest for testing purposes
require(randomForest)
rf2 <- randomForest(iris[,-1],iris[,1],ntree=2)
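# (This is a regression forest: it predicts Sepal.Length from the other four columns)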
# Get predictions and write to csv
write.csv(predict(rf2,iris),file="c:/temp/pred_rf2b.csv")
# Export factor levels
mydata <- iris
type <- sapply(mydata,class)
factors = type[type=="factor"]
output <- lapply(names(factors), function(x){
  res <- data.frame(VarName=x,
                    Level=levels(mydata[,x]),
                    Number=1:nlevels(mydata[,x]))
  return(res)
})
write.csv(do.call(rbind, output),file="c:/temp/factorlevels.csv", row.names=FALSE)
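# For iris the only factor is Species, so factorlevels.csv contains just
# (illustrative - this is what the code above produces, written out by hand):
#   VarName  Level       Number
#   Species  setosa      1
#   Species  versicolor  2
#   Species  virginica   3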
# Export all trees in one file
treeoutput <- lapply(1:rf2$ntree, function(x){
  res <- getTree(rf2, x, labelVar=TRUE)
  res$node <- seq.int(nrow(res))
  res$treenum <- x
  return(res)
})
write.csv(do.call(rbind, treeoutput),file="c:/temp/treeexport.csv", row.names=FALSE)
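# Each row of treeexport.csv is one node of one tree: getTree(..., labelVar=TRUE)
# gives the left/right daughter node numbers, split variable, split point,
# a status flag (-1 = terminal node) and the node's prediction; node and treenum
# are added above so the SAS step can key into the export. A quick sanity check
# (not part of the original code):
head(do.call(rbind, treeoutput))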
/*End of R code*/
/*Import into SAS, replacing . with _ so we have usable variable names*/
proc import
datafile = "c:\temp\treeexport.csv"
out = tree
dbms = csv
replace;
getnames = yes;
run;
data tree;
  set tree;
  SPLIT_VAR = translate(SPLIT_VAR,'_','.');
  format SPLIT_POINT 8.3;
run;
proc import
datafile = "c:\temp\factorlevels.csv"
out = factorlevels
dbms = csv
replace;
getnames = yes;
run;
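/*Rewrite the iris.csv header row so column names like Sepal.Length become Sepal_Length*/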
data _null_;
  infile "c:\temp\iris.csv";
  file "c:\temp\iris2.csv";
  input;
  if _n_ = 1 then _infile_ = translate(_infile_,'_','.');
  put _infile_;
run;
proc import
datafile = "c:\temp\iris2.csv"
out = iris
dbms = csv
replace;
getnames = yes;
run;
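/*Walk each tree depth-first using hash objects and write the equivalent SAS if-then-else code to the log*/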
data _null_;
  debug = 0;
  type = "regression";
  maxiterations = 10000;
  file log;
  if 0 then set tree factorlevels;

  /*Hash to hold the whole tree*/
  declare hash t(dataset:'tree');
  rc = t.definekey('treenum');
  rc = t.definekey('node');
  rc = t.definedata(all:'yes');
  rc = t.definedone();

  /*Hash for looking up factor levels*/
  declare hash fl(dataset:'factorlevels');
  rc = fl.definekey('VARNAME','NUMBER');
  rc = fl.definedata('LEVEL');
  rc = fl.definedone();
  do treenum = 1 by 1 while(t.find(key:treenum,key:1)=0);

    /*Hash to hold the queue for current tree*/
    length position qnode processed 8;
    declare hash q(ordered:'a');
    rc = q.definekey('position');
    rc = q.definedata('qnode','position','processed');
    rc = q.definedone();
    declare hiter qi('q');

    /*Hash for reverse queue lookup*/
    declare hash q2();
    rc = q2.definekey('qnode');
    rc = q2.definedata('position');
    rc = q2.definedone();

    /*Load the starting node for the current tree*/
    node = 1;
    nodetype = 'L'; /*Track whether current node is a Left or Right node*/
    complete = 0;
    length treename $10;
    treename = cats('tree',treenum);
    do iteration = 1 by 1 while(complete = 0 and iteration <= maxiterations);
      rc = t.find();
      if debug then put "Processing node " node;

      /*Logic for terminal nodes*/
      if status = -1 then do;
        if type ne "regression" then prediction = cats('"',prediction,'"');
        put treename '=' prediction ';';

        /*If current node is a right node, remove it from the queue*/
        if nodetype = 'R' then do;
          rc = q2.find();
          if debug then put "Unqueueing node " qnode "in position " position;
          processed = 1;
          rc = q.replace();
        end;

        /*If the queue is empty, we are done*/
        rc = qi.last();
        do while(rc = 0 and processed = 1);
          if position = 1 then complete = 1;
          rc = qi.prev();
        end;

        /*Otherwise, process the most recently queued unprocessed node*/
        if complete = 0 then do;
          put "else ";
          node = qnode;
          nodetype = 'R';
        end;
      end;
      /*Logic for split nodes - status ne -1*/
      else do;
        /*Add right_daughter to queue if present*/
        position = q.num_items + 1;
        qnode = right_daughter;
        processed = 0;
        rc = q.add();
        rc = q2.add();
        if debug then put "Queueing node " qnode "in position " position;

        /*Check whether current split var is a (categorical) factor*/
        rc = fl.find(key:split_var,key:1);

        /*If yes, factor levels corresponding to 1s in the binary representation of the split point go left*/
        if rc = 0 then do;
          /*Get binary representation of split point (least significant bit first)*/
          /*binaryw. format behaves very differently above width 58 - only 58 levels per factor supported here*/
          /*This is sufficient as the R randomForest package only supports 53 levels per factor anyway*/
          binarysplit = reverse(put(split_point,binary58.));
          put 'if ' @;
          j = 0; /*Track how many levels have been encountered for this split var*/
          do i = 1 to 64 while(rc = 0);
            if i > 1 then rc = fl.find(key:split_var,key:i);
            LEVEL = cats('"',LEVEL,'"');
            if debug then put _all_;
            if substr(binarysplit,i,1) = '1' then do;
              if j > 0 then put ' or ' @;
              put split_var ' = ' LEVEL @;
              j + 1;
            end;
          end;
          put 'then';
        end;
        /*If not, anything < split point goes to left child*/
        else put "if " split_var "< " split_point 8.3 " then ";

        if nodetype = 'R' then do;
          qnode = node;
          rc = q2.find();
          if debug then put "Unqueueing node " qnode "in position " position;
          processed = 1;
          rc = q.replace();
        end;

        node = left_daughter;
        nodetype = 'L';
      end;
    end;
    /*End of tree function definition!*/
    put ';';

    /*Clear the queue between trees*/
    rc = q.delete();
    rc = q2.delete();
  end;

  /*We end up going 1 past the actual number of trees after the end of the do loop*/
  treenum = treenum - 1;
  if type = "regression" then do;
    put 'RFprediction=(';
    do i = 1 to treenum;
      treename = cats('tree',i);
      put treename +1 @;
      if i < treenum then put '+' +1 @;
    end;
    put ')/' treenum ';';
  end;

  /*To do - write code to aggregate predictions from multiple trees for classification*/
  stop;
run;
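The one non-obvious piece above is the handling of categorical splits: randomForest stores the split point for a factor as an integer whose binary representation (least significant bit first) flags which levels go to the left daughter, which is what the binary58./reverse() logic decodes. Here is a minimal R sketch of the same decoding; it is not part of the original answer, the helper name is made up, and intToBits only covers 32 levels (the SAS version handles up to 58):

# Hypothetical helper mirroring the binarysplit logic in the SAS step above
levels_going_left <- function(split_point, lvls) {
  # bit i set (least significant bit first) means level i goes to the left daughter
  bits <- as.integer(intToBits(split_point))[seq_along(lvls)] == 1L
  lvls[bits]
}
levels_going_left(5, levels(iris$Species))
# split point 5 = binary 101 -> "setosa" and "virginica" go left, matching the
# generated condition: if Species = "setosa" or Species = "virginica" then ...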
/*Sample of generated if-then-else code */
data predictions;
set iris;
if Petal_Length < 4.150 then
if Petal_Width < 1.050 then
if Petal_Width < 0.350 then
tree1 =4.91702127659574 ;
else
if Petal_Width < 0.450 then
tree1 =5.18333333333333 ;
else
if Species = "versicolor" then
tree1 =5.08888888888889 ;
else
tree1 =5.1 ;
else
if Sepal_Width < 2.550 then
tree1 =5.525 ;
else
if Petal_Length < 4.050 then
tree1 =5.8 ;
else
tree1 =5.63333333333333 ;
else
if Petal_Width < 1.950 then
if Sepal_Width < 3.050 then
if Species = "setosa" or Species = "virginica" then
if Petal_Length < 5.700 then
tree1 =6.05833333333333 ;
else
tree1 =7.2 ;
else
tree1 =6.176 ;
else
if Sepal_Width < 3.250 then
if Sepal_Width < 3.150 then
tree1 =6.62 ;
else
tree1 =6.66666666666667 ;
else
tree1 =6.3 ;
else
if Petal_Length < 6.050 then
if Petal_Width < 2.050 then
tree1 =6.275 ;
else
tree1 =6.65 ;
else
if Petal_Length < 6.550 then
tree1 =7.76666666666667 ;
else
tree1 =7.7 ;
;
if Petal_Width < 1.150 then
if Species = "setosa" then
tree2 =5.08947368421053 ;
else
tree2 =5.55714285714286 ;
else
if Species = "setosa" or Species = "versicolor" then
if Sepal_Width < 2.750 then
if Petal_Length < 4.450 then
tree2 =5.44 ;
else
tree2 =6.06666666666667 ;
else
if Petal_Width < 1.350 then
tree2 =5.85294117647059 ;
else
if Petal_Width < 1.750 then
if Petal_Width < 1.650 then
tree2 =6.3625 ;
else
tree2 =6.7 ;
else
tree2 =5.9 ;
else
if Petal_Length < 5.850 then
if Sepal_Width < 2.650 then
if Petal_Length < 4.750 then
tree2 =4.9 ;
else
if Sepal_Width < 2.350 then
tree2 =6 ;
else
if Sepal_Width < 2.550 then
tree2 =6.14 ;
else
tree2 =6.1 ;
else
tree2 =6.49166666666667 ;
else
if Petal_Length < 6.350 then
tree2 =7.125 ;
else
tree2 =7.775 ;
;
RFprediction=(
tree1 + tree2 )/2 ;
run;
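For classification (the to-do noted in the generator above), the per-tree outputs would need to be combined by majority vote rather than averaged. A rough R sketch of the idea, using invented per-tree predictions rather than real model output:

# Illustrative only: tree1-tree3 values are made up
tree_preds <- data.frame(tree1 = c("setosa", "virginica"),
                         tree2 = c("setosa", "virginica"),
                         tree3 = c("versicolor", "virginica"))
# Majority vote across trees for each observation
RFprediction <- apply(tree_preds, 1, function(x) names(which.max(table(x))))
RFprediction  # "setosa" "virginica"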