Why is PySpark's linear regression giving the wrong answer?

Time: 2019-07-12 18:26:41

Tags: pyspark linear-regression

I kept getting wrong answers, so I tried a very, very basic example, and it is still wrong.

Input file (stupid.txt, LIBSVM format):
1 1:1
2 1:2
3 1:3
4 1:4
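Each line above uses the LIBSVM text format: a label followed by 1-based `index:value` feature pairs, so `3 1:3` is a sample with label 3.0 whose first (and only) feature is 3.0. In other words, the data is exactly y = x. The helper `parse_libsvm_line` below is hypothetical (it is not part of Spark); it is a minimal plain-Python sketch showing which (label, features) pairs the file encodes. Spark's own `libsvm` reader returns the same information as a `label` column and a `features` vector column with 0-based indices.

```python
def parse_libsvm_line(line):
    """Hypothetical parser: 'label idx:val ...' -> (label, {0-based index: value})."""
    parts = line.split()
    label = float(parts[0])
    features = {}
    for item in parts[1:]:
        idx, val = item.split(":")
        # LIBSVM indices start at 1; shift to 0-based like Spark vectors
        features[int(idx) - 1] = float(val)
    return label, features


data = """1 1:1
2 1:2
3 1:3
4 1:4"""

rows = [parse_libsvm_line(line) for line in data.splitlines()]
print(rows)
```

Every row has label equal to its single feature, so an unpenalized fit should give slope 1 and intercept 0.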
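from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession

# Reuse the active SparkSession (the pyspark shell already defines `spark`)
spark = SparkSession.builder.getOrCreate()

# Load training data
training = spark.read.format("libsvm").load("stupid.txt")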

lr = LinearRegression(maxIter=100, regParam=0.3, loss='squaredError')

# Fit the model
lrModel = lr.fit(training)

# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))

# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)

I should be getting coefficient [1] and intercept 0. Instead I get:

Coefficients: [0.7884394856681294]
Intercept: 0.52890128583

1 Answer:

Answer (score: 0)

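The problem appears to be the regParam parameter you are using. If you run it with regParam set to 0, the fit reduces to ordinary least squares (OLS) and you get the expected output: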

Code:

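from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The same four points as stupid.txt, built in memory
training = spark.createDataFrame([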
    (1.0, Vectors.dense(1.0)),
    (2.0, Vectors.dense(2.0)),
    (3.0, Vectors.dense(3.0)),
    (4.0, Vectors.dense(4.0))], ["label", "features"])

# regParam=0 removes the penalty term, so this is ordinary least squares
lr = LinearRegression(maxIter=100, regParam=0.0, loss='squaredError')

# Fit the model
lrModel = lr.fit(training)

# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))

# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)

Output:

Coefficients: [1.0]
Intercept: 0.0
numIterations: 1
objectiveHistory: [0.0]
+---------+
|residuals|
+---------+
|      0.0|
|      0.0|
|      0.0|
|      0.0|
+---------+

RMSE: 0.000000
r2: 1.000000

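It seems that a regParam > 0 is applied as an L2 regularization term (with the default elasticNetParam=0.0 the penalty is pure L2, i.e. ridge regression). The penalty shrinks the coefficient toward zero, which keeps the model from producing the plain OLS solution.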
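Why exactly 0.788...? A minimal sketch, assuming Spark (with the default standardization=True) minimizes (1/2n)·Σ(residual²) + (regParam/2)·w² on data whose feature and label have both been centered and scaled by their population (divide-by-n) standard deviations, with the penalty divided by the label's standard deviation. These scaling details are an assumption about Spark's internals, not something stated in the answer, and `spark_style_ridge_1d` is a hypothetical helper. Under those assumptions the one-feature ridge closed form reproduces the question's coefficient and intercept.

```python
import math


def spark_style_ridge_1d(xs, ys, reg_param):
    """Hypothetical closed-form ridge fit for a single feature, mimicking
    (by assumption) Spark's standardize-then-penalize behaviour."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Assumed: population (divide-by-n) standard deviations
    x_std = math.sqrt(sum((x - x_mean) ** 2 for x in xs) / n)
    y_std = math.sqrt(sum((y - y_mean) ** 2 for y in ys) / n)
    # Center and scale feature and label
    xs_s = [(x - x_mean) / x_std for x in xs]
    ys_s = [(y - y_mean) / y_std for y in ys]
    # Assumed: penalty rescaled by the label's standard deviation
    lam = reg_param / y_std
    # Minimizer of (1/2n) * sum((ys_s - w * xs_s)^2) + (lam / 2) * w^2
    sxy = sum(a * b for a, b in zip(xs_s, ys_s)) / n
    sxx = sum(a * a for a in xs_s) / n
    w_s = sxy / (sxx + lam)
    # Map the slope back to original units and recover the intercept
    coef = w_s * y_std / x_std
    intercept = y_mean - coef * x_mean
    return coef, intercept


xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 2.0, 3.0, 4.0]
print(spark_style_ridge_1d(xs, ys, 0.3))  # roughly (0.7884, 0.5289), as in the question
print(spark_style_ridge_1d(xs, ys, 0.0))  # (1.0, 0.0): no penalty, plain OLS
```

For this perfectly linear data the standardized slope is simply 1/(1 + regParam/σ_y) with σ_y = √1.25, i.e. 1/1.2683 ≈ 0.7884, and the intercept absorbs the rest because the fitted line must still pass through the means (2.5, 2.5).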