R randomForest投票决胜局

时间:2011-12-07 20:45:14

标签: r machine-learning scoring random-forest

有没有人知道R randomForest包用于解决分类关系的机制是什么 - 即当树在两个或更多类中以相同的票数结束时?

文件说领带是随机打破的。但是,当您在一组数据上训练模型,然后使用一组验证数据对该模型进行多次评分时,并列的类决策不是50/50。

cnum = vector("integer",1000)
for (i in 1:length(cnum)){
  cnum[i] = (as.integer(predict(model,val_x[bad_ind[[1]],])))
}
cls = unique(cnum)
for (i in 1:length(cls)){
  print(length(which(cnum == cls[i])))
}

其中model是randomForest对象,而bad_ind只是已绑定类投票的要素向量的索引列表。在我的测试用例中,使用上面的代码,两个绑定类之间的分布更接近90/10。

此外,使用奇数树的建议通常不适用于第三类拉一些选票而另外两个类别的平局。

这些与投票结果相关的案件最终不应该达到50/50吗?

更新 由于森林训练的随机性,很难提供一个例子,但是下面的代码(对于斜坡而言)应该最终产生森林无法确定明显赢家的例子。当关系破裂时,我的测试运行显示66%/ 33%的分布 - 我预计这将是50%/ 50%。

library(randomForest)
x1 = runif(200,-4,4)
x2 = runif(200,-4,4)
x3 = runif(1000,-4,4)
x4 = runif(1000,-4,4)
y1 = dnorm(x1,mean=0,sd=1)
y2 = dnorm(x2,mean=0,sd=1)
y3 = dnorm(x3,mean=0,sd=1)
y4 = dnorm(x4,mean=0,sd=1)
train = data.frame("v1"=y1,"v2"=y2)
val = data.frame("v1"=y3,"v2"=y4)
tlab = vector("integer",length(y1))
tlab_ind = sample(1:length(y1),length(y1)/2)
tlab[tlab_ind]= 1
tlab[-tlab_ind] = 2
tlabf = factor(tlab)
vlab = vector("integer",length(y3))
vlab_ind = sample(1:length(y3),length(y3)/2)
vlab[vlab_ind]= 1
vlab[-vlab_ind] = 2
vlabf = factor(vlab)
mm <- randomForest(x=train,y=tlabf,ntree=100)
out1 <- predict(mm,val)
out2 <- predict(mm,val)
out3 <- predict(mm,val)
outv1 <- predict(mm,val,norm.votes=FALSE,type="vote")
outv2 <- predict(mm,val,norm.votes=FALSE,type="vote")
outv3 <- predict(mm,val,norm.votes=FALSE,type="vote")

(max(as.integer(out1)-as.integer(out2)));(min(as.integer(out1)-as.integer(out2)))
(max(as.integer(out2)-as.integer(out3)));(min(as.integer(out2)-as.integer(out3)))
(max(as.integer(out1)-as.integer(out3)));(min(as.integer(out1)-as.integer(out3)))

bad_ind = vector("list",0)
for (i in 1:length(out1)) {
#for (i in 1:100) {
  if (out1[[i]] != out2[[i]]){
    print(paste(i,out1[[i]],out2[[i]],sep = ";    "))
    bad_ind = append(bad_ind,i)
  }
}

for (j in 1:length(bad_ind)) {
  cnum = vector("integer",1000)
  for (i in 1:length(cnum)) {
    cnum[[i]] = as.integer(predict(mm,val[bad_ind[[j]],]))
  }
  cls = unique(cnum)
  perc_vals = vector("integer",length(cls))
  for (i in 1:length(cls)){
    perc_vals[[i]] = length(which(cnum == cls[i]))
  }
  cat("for feature vector ",bad_ind[[j]]," the class distrbution is: ",perc_vals[[1]]/sum(perc_vals),"/",perc_vals[[2]]/sum(perc_vals),"\n")
}

更新 这应该在randomForest版本4.6-3中修复。

3 个答案:

答案 0 :(得分:1)

如果没有一个完整的例子,很难说这是否是唯一的错误,但上面包含的代码的一个明显问题是你没有复制模型拟合步骤 - 只有预测步骤。当您适合模型时,可以选择任意打破平局,因此如果您不重做该部分,您的predict()调用将继续为同一个类提供更高的概率/投票。

请尝试使用此示例,它会正确演示您所需的行为:

library(randomForest)
df = data.frame(class=factor(rep(1:2, each=5)), X1=rep(c(1,3), each=5), X2=rep(c(2,3), each=5))
fitTie <- function(df) {
  df.rf <- randomForest(class ~ ., data=df)
  predict(df.rf, newdata=data.frame(X1=1, X2=3), type='vote')[1]
}
> df
   class X1 X2
1      1  1  2
2      1  1  2
3      1  1  2
4      1  1  2
5      1  1  2
6      2  3  3
7      2  3  3
8      2  3  3
9      2  3  3
10     2  3  3

> mean(replicate(10000, fitTie(df)))
[1] 0.49989

答案 1 :(得分:1)

我认为这种情况正在发生,因为你有这么少的关系。与掷硬币10次相同的问题,你不能保证在5头5尾巴上卷起来。

在下面的情况1中,领带均匀分开,每个班级1:1。在案例2,3:6中。

> out1[out1 != out2]
 52 109 144 197 314 609 939 950 
  2   2   1   2   2   1   1   1 

> out1[out1 != out3]
 52 144 146 253 314 479 609 841 939 
  2   1   2   2   2   2   1   2   1 

更改为更大的数据集:

x1 = runif(2000,-4,4)
x2 = runif(2000,-4,4)
x3 = runif(10000,-4,4)
x4 = runif(10000,-4,4)

我明白了:

> sum(out1[out1 != out2] == 1)
[1] 39
> sum(out1[out1 != out2] == 2)
[1] 41

> sum(out1[out1 != out3] == 1)
[1] 30
> sum(out1[out1 != out3] == 2)
[1] 31

正如所料,除非我误解你的代码。


修改

哦,我明白了。您正在重新运行具有关系的案例,并期望它们被打破50/50,即:sum(cnum == 1)大约等于sum(cnum == 2)。使用这种方法可以更快地测试:

> for (j in 1:length(bad_ind)) {
+   mydata= data.frame("v1"=0, "v2"=0)
+   mydata[rep(1:1000000),] = val[bad_ind[[j]],]
+   outpred = predict(mm,mydata)
+   print(sum(outpred==1) / sum(outpred==2))
+ }
[1] 0.5007849
[1] 0.5003278
[1] 0.4998868
[1] 0.4995651

看来你是对的,它打破了第2阶段的关系,比第1阶段更频繁。

答案 2 :(得分:1)

这应该在randomForest版本4.6-3中修复。