我正在尝试使用R中的“rpart”包来构建生存树,我希望使用这个树来预测其他观察结果。
我知道有很多涉及rpart和预测的SO问题;但是,我无法找到任何解决问题(我认为)特定于使用带有“Surv”对象的rpart。
我的特殊问题涉及解释“预测”功能的结果。一个例子很有用:
library(rpart)
library(OIsurv)
# Make Data:
set.seed(4)
dat = data.frame(X1 = sample(x = c(1,2,3,4,5), size = 1000, replace=T))
dat$t = rexp(1000, rate=dat$X1)
dat$t = dat$t / max(dat$t)
dat$e = rbinom(n = 1000, size = 1, prob = 1-dat$t )
# Survival Fit:
sfit = survfit(Surv(t, event = e) ~ 1, data=dat)
plot(sfit)
# Tree Fit:
tfit = rpart(formula = Surv(t, event = e) ~ X1 , data = dat, control=rpart.control(minsplit=30, cp=0.01))
plot(tfit); text(tfit)
# Survival Fit, Broken by Node in Tree:
dat$node = as.factor(tfit$where)
plot( survfit(Surv(dat$t, event = dat$e)~dat$node) )
到目前为止一切顺利。我对这里发生的事情的理解是,rpart试图将指数生存曲线拟合到我的数据子集中。基于这种理解,我相信当我调用predict(tfit)
时,对于每次观察,我得到一个对应于该观察的指数曲线参数的数字。因此,例如,如果predict(fit)[1]
是.46,那么对于我原始数据集中的第一次观察,这意味着曲线由等式P(s) = exp(−λt)
给出,其中λ=.46
。
这似乎正是我想要的。对于每个观察(或任何新观察),我可以得到该观察在给定时间点内存活/死亡的预测概率。 (编辑:我意识到这可能是一种误解 - 这些曲线不会给出活着/死亡的概率,而是给出间隔存活的概率。但这并不会改变下面描述的问题。)
但是,当我尝试使用指数公式时......
# Predict:
# an attempt to use the rates extracted from the tree to
# capture the survival curve formula in each tree node.
rates = unique(predict(tfit))
for (rate in rates) {
grid= seq(0,1,length.out = 100)
lines(x= grid, y= exp(-rate*(grid)), col=2)
}
我在这里做的是以与生存树相同的方式拆分数据集,然后使用survfit
为每个分区绘制非参数曲线。那是黑线。我还绘制了与插入(我认为是)'rate'参数(我认为是)生存指数公式的结果对应的行。
我知道非参数和参数拟合不一定相同,但这似乎不止于此:我似乎需要扩展我的X变量或其他东西。
基本上,我似乎并不理解rpart / survival在引擎盖下使用的公式。谁能帮助我从(1)rpart模型得到(2)任意观察的生存方程?
答案 0 :(得分:6)
生存数据按指数内部进行缩放,以便根节点中的预测速率始终固定为1.000
。然后,predict()
方法报告的预测总是相对于根节点中的存活,即,某个因子更高或更低。有关详细信息,请参阅vignette("longintro", package = "rpart")
中的第8.4节。在任何情况下,报告的Kaplan-Meier曲线都与rpart
小插图中报告的曲线完全一致。
如果您想直接获取树中Kaplan-Meier曲线的图并获得预测的中位生存时间,您可以将rpart
树强制转换为constparty
树,如{ {1}}包:
partykit
打印输出显示中位存活时间,并显示相应的Kaplan-Meier曲线。也可以使用library("partykit")
(tfit2 <- as.party(tfit))
## Model formula:
## Surv(t, event = e) ~ X1
##
## Fitted party:
## [1] root
## | [2] X1 < 2.5
## | | [3] X1 < 1.5: 0.192 (n = 213)
## | | [4] X1 >= 1.5: 0.082 (n = 213)
## | [5] X1 >= 2.5: 0.037 (n = 574)
##
## Number of inner nodes: 2
## Number of terminal nodes: 3
##
plot(tfit2)
方法将predict()
参数分别设置为type
和"response"
来获取两者。
"prob"
作为predict(tfit2, type = "response")[1]
## 5
## 0.03671885
predict(tfit2, type = "prob")[[1]]
## Call: survfit(formula = y ~ 1, weights = w, subset = w > 0)
##
## records n.max n.start events median 0.95LCL 0.95UCL
## 574.0000 574.0000 574.0000 542.0000 0.0367 0.0323 0.0408
生存树的替代方法,您还可以考虑基于rpart
中的条件推理(使用logrank得分)或使用常规{{1}的完全参数生存树的非参数生存树。来自ctree()
包的基础设施。