将 R randomForest 包训练的随机森林模型转换为 Base SAS 代码

时间:2017-06-21 16:07:42

标签: r sas random-forest

有没有办法在 R 中构建一个随机森林,并将其自动转换为 SAS 代码,而不必手工输入 getTree 给出的所有 if-else 语句?

我有 30 棵树,getTree 函数的输出一共有 1900 行。

1 个答案:

答案 0 :(得分:0)

这是我手头现成的一段代码,应该能帮你入门。目前只支持回归,但分类应该也可以通过一些额外的工作来实现:

# R code for exporting the randomForest object

# Training data and output folder; every export below is derived from these
# so the CSV files stay consistent with each other.
mydata <- iris
out_dir <- "c:/temp"

# Output dataset to csv for validation in SAS
write.csv(mydata, file = file.path(out_dir, "iris.csv"), row.names = FALSE)

# Train a 2-tree random forest for testing purposes.
# library() stops with an error if the package is missing; require() would
# only return FALSE and let the script fail later with a confusing message.
library(randomForest)
rf2 <- randomForest(mydata[, -1], mydata[, 1], ntree = 2)

# Get predictions and write to csv
write.csv(predict(rf2, mydata), file = file.path(out_dir, "pred_rf2b.csv"))

# Export factor levels (one row per level: factor name, level label, level number).
# Only unordered factors are exported: randomForest treats ordered factors as
# numeric and splits on their integer codes, so they must not be scored with
# the bit-packed categorical split logic used by the SAS generator.
is_unordered_factor <- vapply(mydata,
                              function(col) is.factor(col) && !is.ordered(col),
                              logical(1))
output <- lapply(names(mydata)[is_unordered_factor], function(x) {
  data.frame(VarName = x,
             Level = levels(mydata[[x]]),
             Number = seq_len(nlevels(mydata[[x]])))
})
# Keep the header row even when there are no factors, so the SAS import
# still finds the VarName/Level/Number columns.
factorlevels <- if (length(output) > 0) {
  do.call(rbind, output)
} else {
  data.frame(VarName = character(0), Level = character(0), Number = integer(0))
}
write.csv(factorlevels, file = file.path(out_dir, "factorlevels.csv"), row.names = FALSE)

# Export all trees in one file; node numbers are the getTree row numbers,
# which is what the left/right daughter columns refer to.
treeoutput <- lapply(seq_len(rf2$ntree), function(x) {
  res <- getTree(rf2, x, labelVar = TRUE)
  res$node <- seq_len(nrow(res))
  res$treenum <- x
  res
})

write.csv(do.call(rbind, treeoutput), file = file.path(out_dir, "treeexport.csv"), row.names = FALSE)
# End of R code

/*Import the R exports into SAS, replacing . with _ so we have usable variable names.
  GUESSINGROWS raises the number of rows PROC IMPORT scans before fixing column
  types and character lengths (the default scans only the first 20 rows), so a
  long split variable name or factor level that first appears further down a
  file is not truncated.*/

proc import
  datafile = "c:\temp\treeexport.csv"
  out = tree
  dbms = csv
  replace;
  getnames = yes;
  guessingrows = 32767;
run;

/*R names such as Sepal.Length become Sepal_Length in SAS.
  No display format is attached to SPLIT_POINT: a fixed w.d format would hide
  the full precision of the split thresholds when the data set is inspected.*/
data tree;
  set tree;
  SPLIT_VAR = translate(SPLIT_VAR,'_','.');
run;

proc import
  datafile = "c:\temp\factorlevels.csv"
  out = factorlevels
  dbms = csv
  replace;
  getnames = yes;
  guessingrows = 32767;
run;

/*Apply the same . -> _ renaming to the factor names, so that lookups keyed
  on the translated SPLIT_VAR find the levels of factors whose R names
  contain a dot*/
data factorlevels;
  set factorlevels;
  VARNAME = translate(VARNAME,'_','.');
run;

/*Copy iris.csv, translating . to _ in the header row only.
  LRECL=32767 avoids truncating long records (SAS releases before 9.4
  default to 256 bytes).*/
data _null_;
  infile "c:\temp\iris.csv" lrecl=32767;
  file "c:\temp\iris2.csv" lrecl=32767;
  input;
  if _n_ = 1 then _infile_ = translate(_infile_,'_','.');
  put _infile_;
run;

proc import
  datafile = "c:\temp\iris2.csv"
  out = iris
  dbms = csv
  replace;
  getnames = yes;
  guessingrows = 32767;
run;


/*Generate SAS if-then-else scoring code for every tree in the imported forest.
  The generated code is written to the log (file log). Each tree is walked
  depth-first: a split node writes its test, queues its right daughter and
  moves to its left daughter; a terminal node writes the tree's prediction,
  then resumes (after an "else") at the most recently queued right daughter
  that has not been processed yet.*/
data _null_;
  debug = 0;              /*1 = write node-traversal tracing to the log*/
  type = "regression";    /*only regression aggregation is implemented below*/
  maxiterations = 10000;  /*safety cap on nodes visited per tree*/
  file log;
  /*Scratch variables: quoted factor level, printable split point, and the
    split point's binary expansion (least significant bit first)*/
  length qlevel $ 200 splitstr $ 32 binarysplit $ 58;
  if 0 then set tree factorlevels;
  /*Hash to hold the whole tree, keyed by tree number and node (getTree row) number*/
  declare hash t(dataset:'tree');
  rc = t.definekey('treenum');
  rc = t.definekey('node');
  rc = t.definedata(all:'yes');
  rc = t.definedone();

  /*Hash for looking up factor levels by factor name and level number*/
  declare hash fl(dataset:'factorlevels');
  rc = fl.definekey('VARNAME','NUMBER');
  rc = fl.definedata('LEVEL');
  rc = fl.definedone();

  /*One pass per tree, stopping at the first tree number with no root node*/
  do treenum = 1 by 1 while(t.find(key:treenum,key:1)=0);
    /*Hash to hold the queue for current tree*/
    length position qnode processed 8;
    declare hash q(ordered:'a');
    rc = q.definekey('position');
    rc = q.definedata('qnode','position','processed');
    rc = q.definedone();
    declare hiter qi('q');
    /*Hash for reverse queue lookup (node -> queue position)*/
    declare hash q2();
    rc = q2.definekey('qnode');
    rc = q2.definedata('position');
    rc = q2.definedone();

    /*Load the starting node for the current tree*/
    node = 1;
    nodetype = 'L'; /*Track whether current node is a Left or Right node*/
    complete = 0;
    length treename $10;
    treename = cats('tree',treenum);

    do iteration = 1 by 1 while(complete = 0 and iteration <= maxiterations);
      rc = t.find();
      if debug then put "Processing node " node;

      /*Logic for terminal nodes*/
      if status = -1 then do;
        if type ne "regression" then prediction = cats('"',prediction,'"');
        put treename '=' prediction ';';
        /*If current node is a right node, mark its queue entry as processed*/
        if nodetype = 'R' then do;
          rc = q2.find();
          if debug then put "Unqueueing node " qnode "in position " position;
          processed = 1;
          rc = q.replace();
        end;
        /*Walk the queue backwards from its newest entry looking for an
          unprocessed node. An empty queue means the tree is a single terminal
          node; reaching position 1 with everything processed means the whole
          tree has been written. Either way we are done.*/
        rc = qi.last();
        if rc ne 0 then complete = 1;
        do while(rc = 0 and processed = 1);
          if position = 1 then complete = 1;
          rc = qi.prev();
        end;
        /*Otherwise, process the most recently queued unprocessed node*/
        if complete = 0 then do;
          put "else ";
          node = qnode;
          nodetype = 'R';
        end;
      end;

      /*Logic for split nodes - status ne -1*/
      else do;
        /*Add right_daughter to queue*/
        position = q.num_items + 1;
        qnode = right_daughter;
        processed = 0;
        rc = q.add();
        rc = q2.add();
        if debug then put "Queueing node " qnode "in position " position;

        /*Check whether current split var is a (categorical) factor*/
        rc = fl.find(key:split_var,key:1);
        /*If yes, factor levels corresponding to 1s in the binary representation of the split point go left*/
        if rc = 0 then do;
          /*binaryw. format behaves very differently above width 58 - only 58 levels per factor supported here*/
          /*This is sufficient as the R randomForest package only supports 53 levels per factor anyway*/
          binarysplit = reverse(put(split_point,binary58.));
          put 'if ' @;
          j = 0; /*Track how many levels have been written for this split var*/
          do i = 1 to 58 while(rc = 0);
            if i > 1 then rc = fl.find(key:split_var,key:i);
            if debug then put _all_;
            /*Only use LEVEL when the lookup succeeded; a failed find leaves the previous value in place*/
            if rc = 0 and substr(binarysplit,i,1) = '1' then do;
              /*QUOTE doubles embedded quotes; a separate variable avoids truncating LEVEL*/
              qlevel = quote(trim(LEVEL));
              if j > 0 then put ' or ' @;
              put split_var ' = ' qlevel @;
              j + 1;
            end;
          end;
          /*No level goes left: emit an always-false condition rather than invalid 'if then'*/
          if j = 0 then put '0 ' @;
          put 'then';
        end;
        /*If not, values <= split point go to the left child (randomForest's rule).
          The split point is written at full precision rather than rounded.*/
        else do;
          splitstr = strip(put(split_point,best32.));
          put "if " split_var "<= " splitstr " then ";
        end;
        if nodetype = 'R' then do;
          qnode = node;
          rc = q2.find();
          if debug then put "Unqueueing node " qnode "in position " position;
          processed = 1;
          rc = q.replace();
        end;
        node = left_daughter;
        nodetype = 'L';
      end;
    end;
    /*End of tree function definition!*/
    put ';';
    /*The iteration cap stopped the walk early: the code for this tree is incomplete*/
    if complete = 0 then putlog "WARNING: tree " treenum "was not fully generated within " maxiterations "iterations.";
    /*Clear the queue between trees*/
    rc = q.delete();
    rc = q2.delete();
  end;

  /*We end up going 1 past the actual number of trees after the end of the do loop*/
  treenum = treenum - 1;

  /*Regression: the forest prediction is the mean of the tree predictions*/
  if type = "regression" then do;
    put 'RFprediction=(';
    do i = 1 to treenum;
      treename = cats('tree',i);
      put treename +1 @;
      if i < treenum then put '+' +1 @;
    end;
    put ')/' treenum ';';
  end;

  /*To do - write code to aggregate predictions from multiple trees for classification*/

  stop;
run;


/*Sample of the if-then-else scoring code for the 2-tree regression forest.
  Numeric tests use <= because randomForest sends cases whose value is less
  than or equal to the split point to the left daughter node; each tree ends
  with an empty statement (;), and RFprediction averages the tree outputs.*/

data predictions;
  set iris;
if Petal_Length <= 4.150 then
if Petal_Width <= 1.050 then
if Petal_Width <= 0.350 then
tree1 =4.91702127659574 ;
else
if Petal_Width <= 0.450 then
tree1 =5.18333333333333 ;
else
if Species  = "versicolor" then
tree1 =5.08888888888889 ;
else
tree1 =5.1 ;
else
if Sepal_Width <= 2.550 then
tree1 =5.525 ;
else
if Petal_Length <= 4.050 then
tree1 =5.8 ;
else
tree1 =5.63333333333333 ;
else
if Petal_Width <= 1.950 then
if Sepal_Width <= 3.050 then
if Species  = "setosa"  or Species  = "virginica" then
if Petal_Length <= 5.700 then
tree1 =6.05833333333333 ;
else
tree1 =7.2 ;
else
tree1 =6.176 ;
else
if Sepal_Width <= 3.250 then
if Sepal_Width <= 3.150 then
tree1 =6.62 ;
else
tree1 =6.66666666666667 ;
else
tree1 =6.3 ;
else
if Petal_Length <= 6.050 then
if Petal_Width <= 2.050 then
tree1 =6.275 ;
else
tree1 =6.65 ;
else
if Petal_Length <= 6.550 then
tree1 =7.76666666666667 ;
else
tree1 =7.7 ;
;
if Petal_Width <= 1.150 then
if Species  = "setosa" then
tree2 =5.08947368421053 ;
else
tree2 =5.55714285714286 ;
else
if Species  = "setosa"  or Species  = "versicolor" then
if Sepal_Width <= 2.750 then
if Petal_Length <= 4.450 then
tree2 =5.44 ;
else
tree2 =6.06666666666667 ;
else
if Petal_Width <= 1.350 then
tree2 =5.85294117647059 ;
else
if Petal_Width <= 1.750 then
if Petal_Width <= 1.650 then
tree2 =6.3625 ;
else
tree2 =6.7 ;
else
tree2 =5.9 ;
else
if Petal_Length <= 5.850 then
if Sepal_Width <= 2.650 then
if Petal_Length <= 4.750 then
tree2 =4.9 ;
else
if Sepal_Width <= 2.350 then
tree2 =6 ;
else
if Sepal_Width <= 2.550 then
tree2 =6.14 ;
else
tree2 =6.1 ;
else
tree2 =6.49166666666667 ;
else
if Petal_Length <= 6.350 then
tree2 =7.125 ;
else
tree2 =7.775 ;
;
RFprediction=(
tree1  + tree2  )/2 ;
run;