pyspark - AFT Survival Regression predicts infinity - can't understand why

Time: 2017-08-30 16:12:03

Tags: python apache-spark machine-learning pyspark apache-spark-mllib

I am trying to replicate this Kaggle solution using pyspark.

Data (56 kb): sample

lifetime,broken,pressureInd,moistureInd,temperatureInd,team,provider
56,0,92.17885406,104.2302045,96.51715873,TeamA,Provider4
81,1,72.07593772,103.0657014,87.27106218,TeamC,Provider4
60,0,96.27225443,77.80137602,112.1961703,TeamA,Provider1
86,1,94.40646126,108.4936078,72.02537441,TeamC,Provider2
34,0,97.75289859,99.413492,103.7562706,TeamB,Provider1

The key part of the R code is:

# Choose the dependant variables to be used in the survival regression model.
dependantvars = Surv(maintenance$lifetime, maintenance$broken)

# Create model (use the gaussian method)
survreg = survreg(dependantvars~pressureInd+moistureInd+temperatureInd+team+provider, dist="gaussian",data=maintenance)

I ported this code to pyspark using RFormula, as follows:

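from pyspark.ml.feature import RFormula
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.sql.functions import col

# `spark` is the active SparkSession (for example, the one the pyspark shell provides)
rawDF = spark.read.format("csv").option("header","true").option("inferSchema","true").load("data/maintenance_data.csv")
formula = RFormula(formula="lifetime ~ pressureInd + moistureInd + temperatureInd + team + provider")
output = formula.fit(rawDF).transform(rawDF)
# "broken" (1 = failure observed, 0 = censored) is used as the censor column
final_df = output.select("label",col("broken").cast("double").alias("censor"),"features")
aft = AFTSurvivalRegression(quantileProbabilities=[0.1,0.7],quantilesCol="quantiles")
model = aft.fit(final_df)
model.transform(final_df).show(truncate=False)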

Output:

+-----+------+---------------------------------------------------------+----------+--------------+
|label|censor|features                                                 |prediction|quantiles     |
+-----+------+---------------------------------------------------------+----------+--------------+
|56.0 |0.0   |(8,[0,1,2,4],[92.17885406,104.2302045,96.51715873,1.0])  |Infinity  |[NaN,Infinity]|
|81.0 |1.0   |(8,[0,1,2],[72.07593772,103.0657014,87.27106218])        |Infinity  |[NaN,Infinity]|
|60.0 |0.0   |[96.27225443,77.80137602,112.1961703,0.0,1.0,0.0,1.0,0.0]|Infinity  |[NaN,Infinity]|

Calling model.predict on a single feature vector also gives Infinity.
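For reference, the [NaN, Infinity] pattern is what double-precision arithmetic produces when the fitted linear predictor is too large for exp(). The sketch below is plain Python, not Spark. It mirrors the prediction formulas as I understand them from the Spark source: prediction = exp(x·β + intercept), and quantile = prediction * exp(scale * log(-log(1 - p))). The coefficient, intercept and scale values are hypothetical. They are chosen only to trigger the overflow and are not taken from a fitted model. The helper names (jvm_exp, aft_predict, aft_quantiles) are made up for this illustration.

```python
import math


def jvm_exp(x):
    """exp() that returns inf on overflow instead of raising, like the JVM's Math.exp."""
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf


def aft_predict(coefficients, intercept, features):
    """Weibull AFT point prediction: exp(x . beta + intercept)."""
    linear = sum(c * f for c, f in zip(coefficients, features)) + intercept
    return jvm_exp(linear)


def aft_quantiles(coefficients, intercept, scale, features, probabilities):
    """Weibull AFT quantiles: prediction * exp(scale * log(-log(1 - p)))."""
    lam = aft_predict(coefficients, intercept, features)
    return [lam * jvm_exp(scale * math.log(-math.log1p(-p))) for p in probabilities]


# Third row of the output above, as a dense vector.
features = [96.27225443, 77.80137602, 112.1961703, 0.0, 1.0, 0.0, 1.0, 0.0]

# HYPOTHETICAL parameters: any fit whose linear predictor exceeds roughly 709.78
# (the largest argument exp() can handle in double precision) behaves the same way.
coefficients = [3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]
intercept = 0.0
scale = 1000.0

prediction = aft_predict(coefficients, intercept, features)
quantiles = aft_quantiles(coefficients, intercept, scale, features, [0.1, 0.7])

print(prediction)  # inf
print(quantiles)   # [nan, inf]
```

With a very large scale, the p = 0.1 factor underflows to 0 and inf * 0 gives NaN, while the p = 0.7 factor stays positive and the product stays infinite. If this is what is happening, printing model.coefficients, model.intercept and model.scale on the fitted model should show values that are far too large, which would point to the optimizer diverging rather than to a problem in the transform step.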

Can you tell me what I am doing wrong? Why are my predictions infinite?

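So far, the only difference I can find is that pyspark's AFT survival regression uses a Weibull distribution, whereas the R code from Kaggle uses a Gaussian distribution.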
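To make that difference concrete, here are the two models side by side as I read the respective documentation. The symbols (β₀ for the intercept, β for the coefficients, σ for the scale, ε for the error term) are notation chosen for this note, not taken from the Kaggle solution.

```latex
\begin{align*}
% R: survreg(..., dist = "gaussian") models the lifetime on its original scale
T &= \beta_0 + x^\top \beta + \sigma\,\varepsilon,
  & \varepsilon &\sim \mathcal{N}(0, 1) \\
% Spark: AFTSurvivalRegression (Weibull AFT) models the log of the lifetime
\log T &= \beta_0 + x^\top \beta + \sigma\,\varepsilon,
  & \varepsilon &\sim \text{standard extreme value (Gumbel minimum)} \\
% so the Spark point prediction is on an exponential scale
\hat{T} &= \exp\!\left(\hat\beta_0 + x^\top \hat\beta\right)
\end{align*}
```

Because the Spark prediction goes through exp(), a linear predictor of only a few hundred is already enough to produce Infinity. The Gaussian survreg fit has no such step, so the two models are not expected to give the same coefficients or predictions even when both converge.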

0 answers:

No answers yet.