Taking the example directly from the Spark ML documentation:
# Assumes an existing SparkSession named `spark`.
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

training = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")
#aft = AFTSurvivalRegression()
model = aft.fit(training)
# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)
The result is:
Coefficients: [-0.496304411053,0.198452172529]
Intercept: 2.6380898963056327
Scale: 1.5472363533632303
+-----+------+--------------+------------------+
|label|censor|features      |prediction        |
+-----+------+--------------+------------------+
|1.218|1.0   |[1.56,-0.605] |5.718985621018951 |
|2.949|0.0   |[0.346,2.158] |18.07678210850554 |
|3.627|0.0   |[1.38,0.231]  |7.381908879359964 |
|0.273|1.0   |[0.52,1.151]  |13.577717814884505|
|4.199|0.0   |[0.795,-0.226]|9.013087597344805 |
+-----+------+--------------+------------------+
But if we change every label value to label + 20, as in:
training = spark.createDataFrame([
    (21.218, 1.0, Vectors.dense(1.560, -0.605)),
    (22.949, 0.0, Vectors.dense(0.346, 2.158)),
    (23.627, 0.0, Vectors.dense(1.380, 0.231)),
    (20.273, 1.0, Vectors.dense(0.520, 1.151)),
    (24.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")
#aft = AFTSurvivalRegression()
model = aft.fit(training)
# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)
The result becomes:
Coefficients: [23.9932020748,3.18105314757]
Intercept: 7.35052273751137
Scale: 7698609960.724161
+------+------+--------------+---------------------+---------+
|label |censor|features      |prediction           |quantiles|
+------+------+--------------+---------------------+---------+
|21.218|1.0   |[1.56,-0.605] |4.0912442688237169E18|[0.0,0.0]|
|22.949|0.0   |[0.346,2.158] |6.011158613411288E9  |[0.0,0.0]|
|23.627|0.0   |[1.38,0.231]  |7.7835948690311181E17|[0.0,0.0]|
|20.273|1.0   |[0.52,1.151]  |1.5880852723124176E10|[0.0,0.0]|
|24.199|0.0   |[0.795,-0.226]|1.4590190884193677E11|[0.0,0.0]|
+------+------+--------------+---------------------+---------+
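Can someone explain this exponential explosion in the predictions? As I understand it, the prediction in AFT is the time at which the failure event is predicted to occur, so I don't understand why it changes by so many orders of magnitude.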
Answer 0 (score: 0)
Here is what I get when I run the second example with Spark 2.1:
Coefficients: [-0.065814695216,0.00326705958509]
Intercept: 3.29140205698
Scale: 0.109856123692
+------+------+--------------+------------------+---------------------------------------+
|label |censor|features      |prediction        |quantiles                              |
+------+------+--------------+------------------+---------------------------------------+
|21.218|1.0   |[1.56,-0.605] |24.20972861807431 |[21.61744311047112,23.97833624826161]  |
|22.949|0.0   |[0.346,2.158] |26.461225875981274|[23.6278586196251,26.208314087493847]  |
|23.627|0.0   |[1.38,0.231]  |24.565240805031486|[21.93488840685864,24.330450511651154] |
|20.273|1.0   |[0.52,1.151]  |26.074003958175602|[23.282098949562453,25.82479316934075] |
|24.199|0.0   |[0.795,-0.226]|25.491396901107066|[22.761875236582235,25.247754569057975]|
+------+------+--------------+------------------+---------------------------------------+
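As a sanity check on how these numbers relate to the fitted parameters, here is a minimal sketch in plain Python (no Spark needed). It assumes the Weibull AFT form, in which the prediction is exp(intercept + coefficients · features) and each quantile is the prediction multiplied by (-log(1 - p)) raised to the scale. The helper names `aft_predict` and `aft_quantiles` are made up for illustration and are not part of the Spark API; the parameter values are the rounded ones printed above.

```python
import math

# Parameters as printed in the Spark 2.1 run above (rounded as shown there).
coefficients = [-0.065814695216, 0.00326705958509]
intercept = 3.29140205698
scale = 0.109856123692
quantile_probabilities = [0.3, 0.6]


def aft_predict(features):
    """Assumed Weibull AFT prediction: exp(intercept + coefficients . features)."""
    linear = intercept + sum(c * x for c, x in zip(coefficients, features))
    return math.exp(linear)


def aft_quantiles(features):
    """Assumed Weibull quantiles: prediction * (-log(1 - p)) ** scale for each p."""
    prediction = aft_predict(features)
    return [prediction * (-math.log(1.0 - p)) ** scale
            for p in quantile_probabilities]


first_row = [1.560, -0.605]  # features of the first training row
prediction = aft_predict(first_row)
quantiles = aft_quantiles(first_row)
print(prediction)  # compare with the prediction column of the first row
print(quantiles)   # compare with the quantiles column of the first row
```

Because the linear term sits inside an exponential, a shift of a few units in the intercept or coefficients multiplies the predicted time by a large factor, which is why the prediction is so sensitive to how well the fit converged.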
The model's ParamMap is:
aft.extractParamMap()
{Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='censorCol', doc='censor column name. The value of this column could be 0 or 1. If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored.'): 'censor',
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='maxIter', doc='max number of iterations (>= 0).'): 100,
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='fitIntercept', doc='whether to fit an intercept term.'): True,
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='aggregationDepth', doc='suggested depth for treeAggregate (>= 2).'): 2,
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='labelCol', doc='label column name.'): 'label',
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='featuresCol', doc='features column name.'): 'features',
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='quantilesCol', doc='quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.'): 'quantiles',
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='tol', doc='the convergence tolerance for iterative algorithms (>= 0).'): 1e-06,
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='quantileProbabilities', doc='quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.'): [0.3, 0.6],
Param(parent=u'AFTSurvivalRegression_4a8b957cf888792bb1b8', name='predictionCol', doc='prediction column name.'): 'prediction'}
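Can you check your convergence tolerance (tol), your maximum number of iterations (maxIter), and whether an intercept term is being fitted (fitIntercept)?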
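To see why those settings matter, the sketch below plugs the parameters printed in the question's label + 20 run into the same assumed Weibull AFT formulas: prediction = exp(intercept + coefficients · features), and quantile = prediction * (-log(1 - p)) ** scale. It is plain Python, not Spark code, and the variable names are only illustrative.

```python
import math

# Parameters as printed in the question's second (label + 20) run.
coefficients = [23.9932020748, 3.18105314757]
intercept = 7.35052273751137
scale = 7698609960.724161

features = [1.560, -0.605]  # features of the first training row

# Linear predictor; the prediction is its exponential.
linear = intercept + sum(c * x for c, x in zip(coefficients, features))
prediction = math.exp(linear)

# (-log(1 - p)) is below 1 for p = 0.3 and p = 0.6, so raising it to a
# scale in the billions underflows to 0.0. Written as exp(scale * log(...))
# so the underflow is explicit.
quantiles = [prediction * math.exp(scale * math.log(-math.log(1.0 - p)))
             for p in (0.3, 0.6)]

print(linear)      # roughly 42.86
print(prediction)  # on the order of 4e18, as in the question's first row
print(quantiles)   # [0.0, 0.0]
```

So the huge predictions and the all-zero quantiles both follow mechanically from the coefficients of about 24 and the scale of about 7.7e9. Parameters that extreme are consistent with a fit that did not converge, rather than with anything specific to how AFT interprets the label.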