森林“树”中的权重加起来超过了样本量

时间:2018-06-26 11:05:16

标签: r random-forest party

这是我第一次在这里问问题,所以请客气...我很抱歉这不是问这个问题的正确位置(R编程和stats问题之间的界线可能有点模糊给我)-如果是这样,我会很乐意尝试stackexchange。

我正在R(聚会套餐)中运行cforest以预测数值响应变量,并使用@Marco Sandri在此处建议的出色的“ get_cTree”解决方法:https://stackoverflow.com/a/34534978/9989544 生成一棵树以尝试了解它用于拆分的规则(即使我的主要重点是可变的重要性,这对我也很有趣)。

我期望所有节点上的权重总和等于我的总样本量,如果我只运行一个'ctree',就会发生这种情况。

但是,当使用Marco Sandri的get_cTree代码时,实际上发生的是多对节点权重之和等于我的样本大小,而其余权重根本不等于我的样本大小。总体重量大于我的总样本量。

尝试从条件森林中取出一棵树是自然的结果吗?也就是说,它不是将数据真正分割为单个节点吗? -或者这可以通过编程解决?

这里是一个示例(Marco Sandri的get_cTree代码)。对于虹膜数据集,n = 150。我为cforest获取的节点的权重之和为566,而使用ctree(聚会套餐)则为150。

library(party)

update_tree <- function(x, dt) {
  x <- update_weights(x, dt)
  if(!x$terminal) {
    x$left <- update_tree(x$left, dt)
    x$right <- update_tree(x$right, dt)   
  } 
  x
}

update_weights <- function(x, dt) {
  splt <- x$psplit
  spltClass <- attr(splt,"class")
  spltVarName <- splt$variableName
  spltVar <- dt[,spltVarName]
  spltVarLev <- levels(spltVar)
  if (!is.null(spltClass)) {
    if (spltClass=="nominalSplit") {
      attr(x$psplit$splitpoint,"levels") <- spltVarLev   
      filt <- spltVar %in% spltVarLev[as.logical(x$psplit$splitpoint)] 
    } else {
      filt <- (spltVar <= splt$splitpoint)
    }
    x$left$weights <- as.numeric(filt)
    x$right$weights <- as.numeric(!filt)
  }
  x
}

get_cTree <- function(cf, k=1) {
  dt <- cf@data@get("input")
  tr <- party:::prettytree(cf@ensemble[[k]], names(dt))
  tr_updated <- update_tree(tr, dt)
  new("BinaryTree", tree=tr_updated, data=cf@data, responses=cf@responses, 
      cond_distr_response=cf@cond_distr_response, predict_response=cf@predict_response)
}

attach(iris)

SepalLength <- as.numeric(iris$Sepal.Length)

SepalWidth <- as.numeric(iris$Sepal.Width)

PetalLength <- as.numeric(iris$Petal.Length)

PetalWidth <- as.numeric(iris$Petal.Width)

Species <- as.factor(iris$Species)

mtry=ceiling(sqrt(4))

set.seed(1)

iris_cforest <- cforest(PetalLength~SepalLength+SepalWidth+PetalWidth+Species,controls=cforest_unbiased(ntree=1000,mtry=mtry))

iristree <- get_cTree(iris_cforest)

iristree

plot(iristree)

set.seed(1)

iris_ctree <- ctree(PetalLength~SepalLength+SepalWidth+PetalWidth+Species)

iris_ctree

plot(iris_ctree)

0 个答案:

没有答案