如何使用线性回归模型进行预测?

时间:2017-06-27 15:43:34

标签: java apache-spark linear-regression apache-spark-ml

我目前正致力于线性回归项目,我需要收集数据,将其放在模型上,然后根据测试数据进行预测。

如果我正确,简单的线性回归可以使用两个变量,X(独立)和Y(依赖)。我有以下数据集,我认为time列为X,value列为Y:

+-----+------+
|value|minute|
+-----+------+
| 5000|   672|
| 6000|   673|
| 7000|   676|
| 8000|   678|
| 9000|   680|
+-----+------+

我不知道如何将此数据集正确地拟合到线性回归模型中。我之前使用k-means工作过,我用它做了什么,就是用矢量形式创建一个features列。我对此数据集做了同样的事情:

VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"minute", "value"})
                .setOutputCol("features");

Dataset<Row> vectorData = assembler.transform(dataset);

然后我将其纳入线性回归模型:

LinearRegression lr = new LinearRegression();
LinearRegressionModel model = lr.fit(vectorData);

这是我陷入困境的部分。如何使用此模型进行预测?我想在value等于随机分钟时找到minute的值,例如。 700.

我该怎么做?如何根据随机X值找到Y值的预测/估计?

编辑:线性回归模型是否区分依赖变量和自变量?怎么样?

3 个答案:

答案 0 :(得分:1)

我只是从Spark MLlib开始,特别是线性回归,所以我只能讨论技术问题(不是为什么在机器学习中这样做的原因)。

  

这是我陷入困境的部分。如何使用此模型进行预测?

模型是变换器(如VectorAssembler),它提供了一个与transform运算符非常简单的接口。

  

transform(dataset:Dataset [_]):DataFrame 转换输入数据集。

这是您传递数据集并返回另一个带有prediction列的数据集的位置。这就是训练和预测的一般方法。

以下内容将为您提供输入数据集中功能的预测。

val dataset = ...
model.transform(dataset).select("prediction").show

我强烈建议将Spark MLlib的ML Pipeline功能用于所谓的预测分析工作流程,这使得将原始数据转换为Estimator格式的过程如此之多更愉快。请参阅Machine Learning Library (MLlib) Guide,尤其是ML Pipelines

  

ML Pipelines 提供了一套基于DataFrame构建的统一的高级API,可帮助用户创建和调整实用的机器学习流程。

答案 1 :(得分:1)

感谢@RickMoritz和@JacekLaskowski的反馈,我能够找到解决方案:

LinearRegression确实有X和Y列。 X列是features列,Y列是label列。

因此,在将数据集拟合到LinearRegression模型之前,请务必说明您的labelfeatures列。您可以在定义LinearRegression时设置标签列:

LinearRegression lr = new LinearRegression().setLabelCol(Ycolumn_name);

对于功能列, 确保将X列转换为矢量类型 ,然后您也可以这样做:

LinearRegression lr = new LinearRegression().setFeaturesCol(Xcolumn_name);

一旦你做完了,你就完全了。要获得预测,只需将X值转换为矢量并将其放在LinearRegressionModel的predict()函数上。

答案 2 :(得分:1)

这里是关于线性回归模型的文档

http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html

将您的XTrain,YTrain数据与线性回归模型相匹配。确保XTrain和Y列车是数据帧。

使用pandas将数据转换为数据帧。

现在您可以提供测试数据以预测值

获得最佳估算器使用网格搜索。 http://scikit-learn.org/stable/modules/grid_search.html

相关问题