How to interpret the output of .predict() from a fitted scikit-survival model in Python?

Date: 2017-11-13 22:05:22

Tags: python machine-learning survival-analysis scikit-survival

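I'm confused about how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Here is a minimal example of what leads to my confusion: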

import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis

# load data
data_X, data_y = load_veterans_lung_cancer()

# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']

X = data_X.copy()
for c in categorical_cols:
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)

# display final X to fit Cox Elastic Net model on
del data_X
print(X.head(3))

So here is the X going into the model:

   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0   
1          64.0  squamous             70.0                    5.0   
2          38.0  squamous             60.0                    3.0   

  Prior_therapy Treatment  
0            no  standard  
1           yes  standard  
2            no  standard  

...then fitting the model and generating predictions:

# Fit Model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)    

# What are these predictions?    
preds = coxnet.predict(X)

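preds has the same number of records as X, but its values are very different from the values in data_y, even when predicting on the same data the model was fit on.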

print(preds.mean()) 
print(data_y['Survival_in_days'].mean())

Output:

-0.044114643249153422
121.62773722627738
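The two printed numbers are not comparable: a survival time is measured in days, while a Cox-type risk score is only meaningful relative to other scores. One way to see what such scores do carry is Harrell's concordance index, which checks whether higher scores go with earlier events. Below is a minimal pure-Python sketch using made-up toy numbers (not the veterans data); in practice scikit-survival ships its own implementation as sksurv.metrics.concordance_index_censored.

```python
def harrell_c_index(times, events, risk_scores):
    """Fraction of comparable pairs whose risk scores are ordered correctly.

    A pair (i, j) is comparable when subject i had an observed event
    (events[i] is True) strictly before times[j]. It is concordant when
    the subject who failed first also has the higher risk score.
    Ties in risk score count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    return concordant / comparable


# Hypothetical toy data: survival times in days, event flags, risk scores.
times = [10, 30, 50, 70]
events = [True, True, False, True]
scores = [1.2, 0.4, -0.3, -0.9]

print(harrell_c_index(times, events, scores))  # 1.0: higher score, earlier event

# Shifting or rescaling the scores keeps their order, so C is unchanged.
print(harrell_c_index(times, events, [100 + 5 * s for s in scores]))  # 1.0
```

Because only the ordering matters, the mean of the scores (about -0.04 above) says nothing about how long patients survive.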

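What exactly is preds? Clearly .predict means something quite different here than in scikit-learn, but I can't figure out what. The API Reference says it returns "the predicted decision function", but what does that mean? And how do I get from here to a predicted estimate, in months, yhat for a given X? I'm new to survival analysis, so I'm obviously missing something.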
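For a Cox proportional hazards model (Coxnet is a Cox model fitted with an elastic-net penalty), the "decision function" is usually the linear predictor eta = x . beta, i.e. the log of the relative hazard, which has no time unit. The sketch below uses hypothetical coefficients and feature values, not the fitted coxnet, to show how two such scores turn into a hazard ratio. That scikit-survival's predict returns exactly this dot product (possibly up to a constant offset) is an assumption worth checking against the fitted model's coef_ attribute, which for Coxnet has one column per alpha on the regularization path.

```python
import math

# Hypothetical coefficients for three features (e.g. age, Karnofsky score,
# a treatment dummy); a fitted model would supply these instead.
beta = [0.01, -0.03, 0.2]


def linear_predictor(x, beta):
    """Cox 'decision function': eta = x . beta, the log relative hazard."""
    return sum(xi * bi for xi, bi in zip(x, beta))


patient_a = [69.0, 60.0, 1.0]
patient_b = [64.0, 70.0, 0.0]

eta_a = linear_predictor(patient_a, beta)
eta_b = linear_predictor(patient_b, beta)

# Under proportional hazards, h(t | x) = h0(t) * exp(eta), so the baseline
# hazard h0(t) cancels in the ratio and only the difference of scores matters.
hazard_ratio = math.exp(eta_a - eta_b)
print(round(eta_a, 3), round(eta_b, 3), round(hazard_ratio, 3))
```

Here patient_a has roughly 1.73 times the instantaneous risk of patient_b at every time point, but nothing in eta alone says when either event happens; that needs the baseline hazard h0(t).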

1 Answer

Answer 0 (score: 3)

I posted this question on github, though the author renamed the issue.

I got some helpful explanation of what the predict output is, but I'm still not sure how to get to a set of predicted survival times, which is what I really want. Here are a couple of helpful explanations from that github thread:

predictions are risk scores on an arbitrary scale, which means you can usually only determine the sequence of events, but not their exact time.

-sebp (the library author)

It [predict] returns a type of risk score. Higher value means higher risk of your event (class value = True)...You were probably looking for a predicted time. You can get the predicted survival function with estimator.predict_survival_function as in the example 00 notebook...EDIT: Actually, I’m trying to extract this but it’s been a bit of a pain to munge

-pavopax.

The github thread has more explanation, but I couldn't really follow all of it. I need to experiment with predict_survival_function and predict_cumulative_hazard_function to see whether I can get a set of predictions of the most likely survival time for each row in X, which is what I really want.
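Turning risk scores into times requires the baseline hazard as well, which is what predict_survival_function adds on top of predict. In recent scikit-survival versions CoxnetSurvivalAnalysis only supports this when constructed with fit_baseline_model=True, and it returns one step function per row (objects with x and y arrays). The pure-Python sketch below illustrates the underlying idea under simplifying assumptions: it computes the Breslow estimate of the baseline cumulative hazard H0(t) from given linear predictors, forms S(t | x) = exp(-H0(t) * exp(eta)), and reads off the median survival time, the first event time at which S drops to 0.5 or below. The toy data and the choice of the median (a step survival curve has no well-defined "most likely" time) are assumptions, not something from the thread.

```python
import math


def breslow_baseline(times, events, etas):
    """Breslow estimate of the baseline cumulative hazard H0 at each event time.

    Returns (event_times, cum_hazard) with event_times sorted ascending.
    """
    risk = [math.exp(e) for e in etas]
    event_times = sorted({t for t, d in zip(times, events) if d})
    cum_hazard = []
    total = 0.0
    for tk in event_times:
        deaths = sum(1 for t, d in zip(times, events) if d and t == tk)
        at_risk = sum(r for t, r in zip(times, risk) if t >= tk)  # risk set at tk
        total += deaths / at_risk
        cum_hazard.append(total)
    return event_times, cum_hazard


def survival_curve(eta, event_times, cum_hazard):
    """S(t | x) = exp(-H0(t) * exp(eta)), evaluated at each event time."""
    return [math.exp(-h * math.exp(eta)) for h in cum_hazard]


def median_survival_time(event_times, surv):
    """First event time with S(t) <= 0.5, or None if the curve never gets there."""
    for t, s in zip(event_times, surv):
        if s <= 0.5:
            return t
    return None


# Hypothetical toy data: days, event flags, and linear predictors (risk scores).
times = [5, 8, 12, 20, 25, 40]
events = [True, True, False, True, True, False]
etas = [0.8, 0.5, 0.1, -0.2, -0.4, -0.9]

event_times, H0 = breslow_baseline(times, events, etas)
for eta in (0.8, -0.9):
    surv = survival_curve(eta, event_times, H0)
    print(eta, median_survival_time(event_times, surv))
```

With these numbers the high-risk profile (eta = 0.8) reaches S <= 0.5 at day 8 and the low-risk one (eta = -0.9) at day 25. When a curve never falls to 0.5 within the follow-up, the median is undefined, which is common with heavy censoring; the same reading can be applied to the x and y arrays of the step functions scikit-survival returns.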

I won't accept this answer yet, in case someone else has a better one.