我有以下代码，用于列出 ctree 的所有终端节点及其对应的划分路径。我想为训练集（本例中即用于训练模型的 airq）中的每条记录添加其所属终端节点的编号。也就是说，我要在 airq 中新增一列 TN（terminal node，终端节点），存放每条记录所落入的终端节点编号。
#' Describe the decision path leading to every terminal node of a ctree
#'
#' For each terminal node reported by `where(ct)` (i.e. every terminal node
#' holding at least one training row), collect its ancestor split nodes and
#' render each split as "<variable> <= <value>" or "<variable> > <value>",
#' joined with ", ".
#'
#' @param ct   A tree fitted with `party::ctree()`.
#' @param data The data frame `ct` was trained on. Node weights index its
#'   rows, so it must contain the training rows in the same order.
#' @return A data.frame with one row per terminal node (in the order of
#'   `unique(where(ct))`) and character columns `Node` (terminal node id) and
#'   `Path` (rendered split path; "" when the tree has no splits).
#' @note Split points are compared with `as.numeric()`, so presumably only
#'   ordered/numeric splits render meaningfully; factor splits are not handled.
#'   To label each training row with its terminal node, `where(ct)` already
#'   gives that directly, e.g. `data$TN <- where(ct)`.
CtreePathFunc <- function(ct, data) {
  terminal_ids <- unique(where(ct))

  # Rows of `data` that fall into node `id` (non-zero node weight).
  in_node <- function(id) {
    as.logical(nodes(ct, id)[[1]][[2]])
  }

  # Split variable and split point of inner node `id`. Assumes party's psplit
  # layout: the 3rd unlisted entry is the split point, the last one the
  # variable name.
  split_info <- function(id) {
    parts <- unlist(nodes(ct, id)[[1]][[5]])
    list(var = as.character(parts[length(parts)]), point = parts[3])
  }

  # Render the ", "-joined split path of terminal node `node`.
  render_path <- function(node) {
    # Candidate ancestors: inner nodes with a smaller id. seq_len() is empty
    # for the root, whereas 1:(node - 1) would yield c(1, 0).
    inner <- setdiff(seq_len(node - 1), terminal_ids)
    node_rows <- in_node(node)
    # Assumes party's depth-first numbering (ancestors have smaller ids), so a
    # smaller inner node sharing any row with `node` is one of its ancestors.
    path <- Filter(function(i) any(node_rows & in_node(i)), inner)
    if (length(path) == 0) {
      return("")
    }

    # Node directly below each ancestor: the next ancestor, or the terminal
    # node itself for the last split.
    children <- c(path[-1], node)
    steps <- vapply(seq_along(path), function(k) {
      sp <- split_info(path[k])
      child_rows <- which(in_node(children[k]))
      # na.rm: rows whose split variable is NA (sent down by surrogate splits)
      # say nothing about the direction and would make all() return NA.
      goes_left <- all(
        data[[sp$var]][child_rows] <= as.numeric(sp$point),
        na.rm = TRUE
      )
      paste(sp$var, if (goes_left) "<=" else ">", as.character(sp$point))
    }, character(1))
    paste(steps, collapse = ", ")
  }

  paths <- vapply(terminal_ids, render_path, character(1))
  data.frame(
    Node = as.character(terminal_ids),
    Path = paths,
    stringsAsFactors = FALSE
  )
}
library(party)

# Training data: keep only rows with an observed response (Ozone).
airq <- subset(airquality, !is.na(Ozone))
# Conditional inference tree using all other columns as predictors;
# maxsurrogate = 3 lets ctree evaluate up to 3 surrogate splits
# (see ?ctree_control).
ct <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
Result <- CtreePathFunc(ct, airq)

# TN: terminal node id of each training row, in the row order of `airq`.
airq$TN <- where(ct)
> Result
Node Path
1 5 Temp <= 82, Wind > 6.9, Temp <= 77
2 3 Temp <= 82, Wind <= 6.9
3 6 Temp <= 82, Wind > 6.9, Temp > 77
4 9 Temp > 82, Wind > 10.3
5 8 Temp > 82, Wind <= 10.3