如何限制Spark中的预测范围?

时间:2017-07-31 06:44:44

标签: scala apache-spark machine-learning

我在星云范围为1-5的数据集中使用Spark ml中的线性回归进行预测:

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
  .setFeaturesCol(featureCol).setLabelCol(labelCol)

// Fit the model
val lrModel = lr.fit(dataFrame)
val result = lrModel.transform(data)
result.show()

但有些预测> 5:

+--------------------+-------+-----------+---+------------------+
|   topicDistribution|user_id|business_id|  s|        prediction|
+--------------------+-------+-----------+---+------------------+
|[1.0,2.0,3.0,4.0,...|   user|       item|  0|               0.0|
|[0.01514119038647...|      2|          1|  4|4.3475413742362665|
|[0.03940825720524...|      2|          4|  3| 6.916754074011433|
|[0.01514116632977...|      2|          1|  4| 4.245671097612515|
|[0.01786143737009...|      2|          5|  5| 4.753807934900515|
|[0.03943774853904...|      2|          4|  3| 6.973022108753978|
|[0.04868600587994...|      3|          2|  4| 3.648043391726578|
|[0.01515983372328...|      2|          2|  4| 4.246801262511743|
|[0.01786135762750...|      2|          1|  5| 4.753905610858851|
|[0.03940799263407...|      2|          4|  3| 6.970579591530296|
|[0.04868653016151...|      3|          2|  4|3.6480609281936154|
+--------------------+-------+-----------+---+------------------+

如何限制范围[1,5]的预测? 或者将预测转换为[1,5]的方法。

2 个答案:

答案 0 :(得分:0)

我使用Spark SQL过滤和更改值:

GET

有人可能有更好的解决方案。

答案 1 :(得分:0)

val filter1 = ss.filter(" prediction")
                .toDF("topicDistribution","user_id","business_id","s",col)

这是你可以尝试的方式。