I keep getting wrong answers, so I tried a very, very basic case, and the result is still wrong.
Input file:
1 1:1
2 1:2
3 1:3
4 1:4
from pyspark.ml.regression import LinearRegression
# Load training data
training = spark.read.format("libsvm").load("stupid.txt")
lr = LinearRegression(maxIter=100, regParam=0.3, loss='squaredError')
# Fit the model
lrModel = lr.fit(training)
# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))
# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)
I should get coefficients [1] and intercept 0. Instead I get:
Coefficients: [0.7884394856681294]
Intercept: 0.52890128583
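Answer 0 (score: 0)
The problem appears to be the regParam parameter you are using. Setting it to 0 makes the fit ordinary least squares (OLS), and we get the expected output:
Code: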
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
training = spark.createDataFrame([
(1.0, Vectors.dense(1.0)),
(2.0, Vectors.dense(2.0)),
(3.0, Vectors.dense(3.0)),
(4.0, Vectors.dense(4.0))], ["label", "features"])
lr = LinearRegression(maxIter=100, regParam=0, loss='squaredError')
# Fit the model
lrModel = lr.fit(training)
# Print the coefficients and intercept for linear regression
print("Coefficients: %s" % str(lrModel.coefficients))
print("Intercept: %s" % str(lrModel.intercept))
# Summarize the model over the training set and print out some metrics
trainingSummary = lrModel.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show()
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)
Output:
Coefficients: [1.0]
Intercept: 0.0
numIterations: 1
objectiveHistory: [0.0]
+---------+
|residuals|
+---------+
|      0.0|
|      0.0|
|      0.0|
|      0.0|
+---------+
RMSE: 0.000000
r2: 1.000000
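It seems that regParam > 0 acts as an L2 regularization term (the default elasticNetParam is 0.0, which means a pure L2 penalty). The fit is then ridge regression rather than plain OLS, so the coefficient is shrunk toward 0 and the intercept moves away from 0 to compensate.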
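
To see where 0.788 comes from, here is a minimal pure-Python sketch (not Spark code) of a one-feature ridge fit. It rests on an assumption about Spark's internals: that with the default standardization the penalty is rescaled by the standard deviations of the feature and the label (population form), and that the intercept is not penalized. The function name ridge_1d is hypothetical and only for illustration.

```python
from math import sqrt


def ridge_1d(xs, ys, reg_param):
    """Closed-form fit of y = w*x + b with an L2 penalty on w only.

    ASSUMPTION (not taken from the question or the answer): the objective is
        1/(2n) * sum((y - w*x - b)^2) + (reg_param / std_y) * 0.5 * (std_x * w)^2
    where std_x and std_y are population standard deviations.
    """
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    var_x = sum((x - mean_x) ** 2 for x in xs) / n
    var_y = sum((y - mean_y) ** 2 for y in ys) / n
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / n

    # With reg_param == 0 the penalty vanishes and this is plain OLS.
    penalty_scale = reg_param / sqrt(var_y) if reg_param else 0.0

    # Setting the derivative with respect to w to zero gives:
    #   w = cov_xy / (var_x * (1 + reg_param / std_y))
    w = cov_xy / (var_x * (1.0 + penalty_scale))
    b = mean_y - w * mean_x  # the intercept is left unpenalized
    return w, b


xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 2.0, 3.0, 4.0]
print(ridge_1d(xs, ys, 0.3))  # shrunk slope, non-zero intercept
print(ridge_1d(xs, ys, 0.0))  # plain OLS: slope 1, intercept 0
```

Under that assumption, regParam=0.3 gives roughly 0.7884 and 0.5289, the same values reported in the question, and regParam=0 recovers the exact line y = x.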