如何使用随机森林获得课程的重要性?

时间:2015-12-03 23:07:31

标签: variables random-forest decision-tree

我在我的数据集中使用randomForest包来进行分类,但是使用importance命令我只能获得变量的重要性。那么,如果我想通过特定类别的变量来获得变量的重要性?就像区域变量中的特定位置一样,该区域对总数的影响程度。我想在变压器的每一个班级都有一个假人,但我不知道这是不是一个好主意。

1 个答案:

答案 0 :(得分:0)

我认为你的意思是"特定类别的变量的变量重要性"。这还没有实现,但我想这将是可能的,有意义的,也许是有用的。当然,对于只有两个类别的变量,它没有意义。

我会实现如下: 火车模型 - >计算袋外预测性能(OOB-cv1) - >通过特定变量对特定类别进行置换(将此类别随机重新分配给其他类别,按其他类别流行度加权) - >重新计算袋外预测性能(OOB-cv2) - >从OOB-cv2中减去OOB-cv1

然后我编写了一个实现分类特定变量重要性的函数。

library(randomForest)

#Create some classification problem, with mixed categorical and numeric vars
#Cat A of var 1, cat B of var 2 and Cat C of var 3 influence class the most.
X.cat = replicate(3,sample(c("A","B","C"),600,rep=T))
X.val = replicate(2,rnorm(600))
y.cat = 3*(X.cat[,1]=="A") + 3*(X.cat[,2]=="B") + 3*(X.cat[,3]=="C")
y.cat.err = y.cat+rnorm(600)
y.lim = quantile(y.cat.err,c(1/3,2/3))
y.class = apply(replicate(2,y.cat.err),1,function(x) sum(x>y.lim)+1)
y.class = factor(y.class,labels=c("ann","bob","chris"))  
X.full = data.frame(X.cat,X.val)
X.full[1:3] = lapply(X.full[1:3],as.factor)

#train forest
rf=randomForest(X.full,y.class,keep.inbag=T,replace=T)

#make function to compute crovalidated classification error
oobErr = function(rf,X) {
  preds = predict(rf,X,type="vote",predict.all = T)$individual
  preds[rf$inbag!=0]=NA
  oob.pred = apply(preds,1,function(x) {
    tabx=sort(table(x),dec=T)
    majority.vote = names(tabx)[1]
  })
  return(mean(as.character(rf$y)!=oob.pred))
}

#make function to iterate all categories of categorical variables
#and compute change of OOB class error due to permutation of category
catVar = function(rf,X,nPerm=2) {
  ref = oobErr(rf,X)
  catVars = which(rf$forest$ncat>1)
  lapply(catVars, function(iVar) {
    catImp = replicate(nPerm,{
      sapply(levels(X[[iVar]]), function(thisCat) {
        thisCat.ind = which(thisCat==X[[iVar]])
        X[thisCat.ind,iVar] = head(sample(X[[iVar]]),length(thisCat.ind))
        varImp = oobErr(rf,X)-ref
      })
    })
    if(nPerm==1) catImp else apply(catImp,1,mean) 
  })
}

#try it out
out = catVar(rf,X.full,nPerm=4)
print(out) #seems like it works as it should

$X1
      A       B       C 
0.14000 0.07125 0.06875 

$X2
         A          B          C 
0.07458333 0.16083333 0.07666667 

$X3
         A          B          C 
0.05333333 0.08083333 0.15375000