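I have a simple CSV file (99 observations, attached) and want to run a simple regression to predict one of its fields (energy consumption, `y_3`). I used linear regression in R (the `lm` method in caret; code at the end) and `LinearRegressionWithSGD` in Spark MLlib (code at the end).
But the RMSE from Spark is 20 times the RMSE from R.
I'm confused by these results; any help would be appreciated. Thanks.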
CSV:
y_3,x_6,x_7,x_73_1,x_73_2,x_73_3,x_8
2995.3846153846152,17.0,1800.0,0.0,1.0,0.0,12.0
2236.304347826087,17.0,1432.0,1.0,0.0,0.0,12.0
2001.9512195121952,35.0,1432.0,0.0,1.0,0.0,5.0
992.4324324324324,17.0,1430.0,1.0,0.0,0.0,12.0
4386.666666666667,26.0,1430.0,0.0,0.0,1.0,25.0
1335.9036144578313,17.0,1432.0,0.0,1.0,0.0,5.0
1097.560975609756,17.0,1100.0,0.0,1.0,0.0,5.0
3526.6666666666665,26.0,1432.0,0.0,1.0,0.0,12.0
506.8421052631579,17.0,1430.0,1.0,0.0,0.0,5.0
2095.890410958904,35.0,1430.0,1.0,0.0,0.0,12.0
720.0,35.0,1430.0,1.0,0.0,0.0,5.0
2416.5,17.0,1432.0,0.0,0.0,1.0,12.0
3306.6666666666665,35.0,1800.0,0.0,0.0,1.0,12.0
6105.974025974026,35.0,1800.0,1.0,0.0,0.0,25.0
1400.4624277456646,35.0,1800.0,1.0,0.0,0.0,5.0
1414.5454545454545,26.0,1430.0,1.0,0.0,0.0,12.0
5204.68085106383,26.0,1800.0,0.0,0.0,1.0,25.0
1812.2222222222222,17.0,1800.0,1.0,0.0,0.0,12.0
2763.5928143712576,35.0,1100.0,1.0,0.0,0.0,12.0
1192.5991189427314,26.0,1800.0,1.0,0.0,0.0,5.0
5332.075471698113,26.0,1432.0,1.0,0.0,0.0,25.0
3318.157894736842,17.0,1100.0,1.0,0.0,0.0,25.0
1551.063829787234,35.0,1432.0,1.0,0.0,0.0,5.0
7998.75,35.0,1800.0,0.0,1.0,0.0,25.0
1800.0,26.0,1800.0,0.0,1.0,0.0,5.0
1829.0683229813665,17.0,1100.0,0.0,0.0,1.0,12.0
3396.0,35.0,1432.0,1.0,0.0,0.0,12.0
1911.6666666666667,26.0,1100.0,1.0,0.0,0.0,12.0
2160.0,26.0,1430.0,0.0,0.0,1.0,12.0
1777.5,17.0,1430.0,0.0,0.0,1.0,12.0
2089.756097560976,35.0,1800.0,0.0,1.0,0.0,5.0
740.8092485549133,17.0,1100.0,0.0,0.0,1.0,5.0
617.4683544303797,17.0,1100.0,1.0,0.0,0.0,5.0
3787.714285714286,17.0,1800.0,0.0,0.0,1.0,25.0
2450.958904109589,26.0,1800.0,1.0,0.0,0.0,12.0
531.0891089108911,26.0,1430.0,1.0,0.0,0.0,5.0
1108.1739130434783,26.0,1100.0,0.0,0.0,1.0,5.0
6538.378378378378,26.0,1432.0,0.0,1.0,0.0,25.0
6641.0526315789475,26.0,1800.0,0.0,1.0,0.0,25.0
2367.169811320755,26.0,1100.0,0.0,0.0,1.0,12.0
3708.2517482517483,17.0,1100.0,0.0,0.0,1.0,25.0
3253.090909090909,26.0,1100.0,0.0,1.0,0.0,12.0
1824.5454545454545,35.0,1432.0,0.0,0.0,1.0,5.0
2153.1645569620255,17.0,1800.0,0.0,0.0,1.0,12.0
3633.3333333333335,26.0,1800.0,0.0,1.0,0.0,12.0
4645.714285714285,17.0,1100.0,0.0,1.0,0.0,25.0
1575.5294117647059,35.0,1800.0,0.0,0.0,1.0,5.0
1146.6666666666667,17.0,1800.0,0.0,1.0,0.0,5.0
3042.3529411764707,26.0,1432.0,0.0,0.0,1.0,12.0
3841.2,26.0,1430.0,1.0,0.0,0.0,25.0
7625.060240963855,35.0,1432.0,1.0,0.0,0.0,25.0
1120.9580838323354,17.0,1432.0,1.0,0.0,0.0,5.0
936.8421052631579,35.0,1100.0,1.0,0.0,0.0,5.0
2637.0,35.0,1430.0,0.0,0.0,1.0,12.0
3738.4615384615386,35.0,1432.0,0.0,0.0,1.0,12.0
7002.580645161291,35.0,1432.0,0.0,0.0,1.0,25.0
2853.6585365853657,35.0,1100.0,0.0,0.0,1.0,12.0
2553.846153846154,26.0,1432.0,1.0,0.0,0.0,12.0
1153.2558139534883,17.0,1800.0,0.0,0.0,1.0,5.0
3793.170731707317,35.0,1800.0,0.0,1.0,0.0,12.0
1106.1818181818182,26.0,1100.0,0.0,1.0,0.0,5.0
1703.5714285714287,26.0,1432.0,0.0,1.0,0.0,5.0
1269.4736842105262,26.0,1800.0,0.0,0.0,1.0,5.0
2422.0408163265306,17.0,1430.0,1.0,0.0,0.0,25.0
1301.860465116279,35.0,1100.0,0.0,0.0,1.0,5.0
6849.0,26.0,1100.0,0.0,1.0,0.0,25.0
3285.0,17.0,1430.0,0.0,0.0,1.0,25.0
4550.163934426229,17.0,1432.0,0.0,0.0,1.0,25.0
4993.548387096775,17.0,1432.0,0.0,1.0,0.0,25.0
6264.0,35.0,1800.0,0.0,0.0,1.0,25.0
1016.0,17.0,1432.0,0.0,0.0,1.0,5.0
5387.586206896552,17.0,1800.0,0.0,1.0,0.0,25.0
720.0,17.0,1430.0,0.0,0.0,1.0,5.0
7941.818181818182,35.0,1432.0,0.0,1.0,0.0,25.0
4184.1509433962265,17.0,1432.0,1.0,0.0,0.0,25.0
3758.048780487805,35.0,1432.0,0.0,1.0,0.0,12.0
6425.454545454545,35.0,1100.0,0.0,0.0,1.0,25.0
1071.36,26.0,1432.0,1.0,0.0,0.0,5.0
974.7692307692307,17.0,1800.0,1.0,0.0,0.0,5.0
2659.8058252427186,26.0,1800.0,0.0,0.0,1.0,12.0
5697.0,35.0,1430.0,0.0,0.0,1.0,25.0
4967.027027027027,35.0,1430.0,1.0,0.0,0.0,25.0
5734.285714285715,35.0,1100.0,1.0,0.0,0.0,25.0
1529.1640866873065,17.0,1100.0,1.0,0.0,0.0,12.0
2667.6923076923076,17.0,1432.0,0.0,1.0,0.0,12.0
2655.5844155844156,17.0,1100.0,0.0,1.0,0.0,12.0
831.0132158590309,26.0,1100.0,1.0,0.0,0.0,5.0
3635.121951219512,35.0,1100.0,0.0,1.0,0.0,12.0
1080.0,35.0,1430.0,0.0,0.0,1.0,5.0
4947.096774193548,26.0,1100.0,0.0,0.0,1.0,25.0
968.7272727272727,26.0,1430.0,0.0,0.0,1.0,5.0
1626.0,26.0,1432.0,0.0,0.0,1.0,5.0
2788.3229813664598,35.0,1800.0,1.0,0.0,0.0,12.0
1800.0,35.0,1100.0,0.0,1.0,0.0,5.0
5796.923076923077,26.0,1432.0,0.0,0.0,1.0,25.0
5105.142857142857,26.0,1800.0,1.0,0.0,0.0,25.0
4594.782608695652,26.0,1100.0,1.0,0.0,0.0,25.0
3519.4736842105262,17.0,1800.0,1.0,0.0,0.0,25.0
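Each data row has seven comma-separated fields: the label `y_3` followed by six features. In every row shown, exactly one of `x_73_1`, `x_73_2` and `x_73_3` is 1.0. These three columns therefore look like a one-hot encoding of a single categorical variable. Together with an intercept they are linearly dependent, and R's `lm` handles this by reporting `NA` for one of the dummy coefficients.

The plain-Scala sketch below (no Spark) parses a row into a label plus all six features. `CsvRows`, `parse`, `postedFeatures` and `isOneHot` are hypothetical names. `postedFeatures` reproduces the feature expression `parts(1).split(' ')` from the Spark code in this post, which yields only `x_6`.

```scala
// Hypothetical helper object: parse rows of the CSV into (label, features).
object CsvRows {
  val header = "y_3,x_6,x_7,x_73_1,x_73_2,x_73_3,x_8"

  // Label = first field (y_3); features = the remaining six fields.
  def parse(line: String): (Double, Array[Double]) = {
    val parts = line.split(',').map(_.trim.toDouble)
    (parts.head, parts.tail)
  }

  // Same expression as the Spark parser in the post: field 1 contains no spaces,
  // so split(' ') returns a single element and only x_6 survives.
  def postedFeatures(line: String): Array[Double] =
    line.split(',')(1).split(' ').map(_.toDouble)

  // Features 2..4 are x_73_1, x_73_2, x_73_3; exactly one should be 1.0 per row.
  def isOneHot(features: Array[Double]): Boolean =
    features.slice(2, 5).sum == 1.0

  def main(args: Array[String]): Unit = {
    val lines = Seq(
      "2995.3846153846152,17.0,1800.0,0.0,1.0,0.0,12.0",
      "4386.666666666667,26.0,1430.0,0.0,0.0,1.0,25.0",
      "720.0,35.0,1430.0,1.0,0.0,0.0,5.0"
    )
    for (line <- lines) {
      val (y, x) = parse(line)
      val posted = postedFeatures(line).mkString("[", ", ", "]")
      println(s"y_3=$y features=${x.mkString("[", ", ", "]")} oneHot=${isOneHot(x)} posted=$posted")
    }
  }
}
```

Running `CsvRows.main` prints each parsed row, its one-hot check, and the one-element vector that the posted expression produces.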
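R code:
```r
library(caret)
library(fscaret)
set.seed(123)
data.main <- read.csv(file = "data_2.csv", header = TRUE, sep = ",")
# 80/20 train/test split, stratified on the outcome y_3
split <- createDataPartition(y = data.main$y_3, p = 0.8, list = FALSE)
train <- data.main[split, ]
test <- data.main[-split, ]
# Ordinary least squares on all other columns through caret's "lm" method
lmCVFit.lm <- train(y_3 ~ ., data = train, method = "lm")
# Predict on the test set; column 1 (y_3) is dropped from the predictors
p <- predict(lmCVFit.lm, test[, -1])
t <- test$y_3
n <- nrow(test)
mse.lm <- MSE(p, t, n)
rmse.lm <- RMSE(p, t, n)
```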
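Both programs finish with the same metric: the root-mean-squared error over the test rows, RMSE = sqrt(mean((prediction − y_3)²)). The two test sets differ, though. caret's `createDataPartition` (seed 123) and Spark's `randomSplit` (seed 11) do not pick the same rows, so the two numbers are not strictly comparable. Still, the split alone is unlikely to explain a factor of 20.

A minimal plain-Scala version of the metric, useful for checking either result by hand, follows. `Metrics`, `mse` and `rmse` are hypothetical helpers, not functions from caret, fscaret or Spark.

```scala
// Hypothetical helpers: mean squared error and its square root.
object Metrics {
  def mse(pred: Seq[Double], obs: Seq[Double]): Double = {
    require(pred.size == obs.size && pred.nonEmpty, "need equally sized, non-empty inputs")
    pred.zip(obs).map { case (p, o) => (p - o) * (p - o) }.sum / pred.size
  }

  def rmse(pred: Seq[Double], obs: Seq[Double]): Double = math.sqrt(mse(pred, obs))

  def main(args: Array[String]): Unit = {
    // Errors of -3 and -4: MSE = 12.5, RMSE is about 3.54
    println(rmse(Seq(1.0, 2.0), Seq(4.0, 6.0)))
  }
}
```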
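Spark code:
```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

// sc is the SparkContext (e.g. the one provided by spark-shell)
val datadir = "data directory"
val headerAndRows = sc.textFile(datadir + "/data_2.csv")
val header = headerAndRows.first
// Drop the header: keep lines whose first character differs from the header's ('y')
val data = headerAndRows.filter(_(0) != header(0))
val parsedData = data.map { line =>
  val parts = line.split(',')
  // Label = y_3 (field 0). Features = parts(1).split(' '): field 1 has no spaces,
  // so this yields a one-element vector containing only x_6.
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}.cache()

// Split data into training (80%) and test (20%) sets
val splits = parsedData.randomSplit(Array(0.8, 0.2), seed = 11L)
val train = splits(0).cache()
val test = splits(1).cache()

// Linear regression fitted by SGD, with an intercept and L2 regularization
val algorithm = new LinearRegressionWithSGD()
algorithm.setIntercept(true)
algorithm.optimizer
  .setNumIterations(100)
  .setStepSize(0.5)
  .setUpdater(new SquaredL2Updater())
  .setRegParam(10.0)
val new_model = algorithm.run(train)

// Pair each test label with the model's prediction
val valuesAndPreds = test.map { point =>
  val prediction = new_model.predict(point.features)
  (point.label, prediction)
}
// Test-set MSE and RMSE
val residuals = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }
val MSE = residuals.mean()
val RMSE = math.sqrt(MSE)
```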
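Two settings in this Spark code differ sharply from R's `lm`, which solves least squares directly (via a QR decomposition) with no step size and no penalty.

- **Feature scale.** Gradient descent with a fixed step size is sensitive to feature scale: `x_7` lies around 1100 to 1800, while the dummy columns are 0 or 1.
- **Regularization.** As I read MLlib's `SquaredL2Updater`, iteration t multiplies the weights by 1 − (stepSize/√t)·regParam. With stepSize 0.5 and regParam 10.0 that factor is −4 on the first iteration.

The sketch below is plain Scala, not MLlib. It is a full-batch gradient descent for least squares, loosely modeled on that update (step stepSize/√t, intercept as an appended 1.0 feature). It runs on ten rows copied from the CSV, once on raw and once on standardized features, both with regParam 0 so that only scaling differs. `SgdScalingSketch`, `fit`, `standardize` and `l2Factor` are hypothetical names.

```scala
// Plain-Scala sketch (not MLlib): w <- w * (1 - eta_t * regParam) - eta_t * gradient,
// with eta_t = stepSize / sqrt(t) and the intercept as an appended 1.0 feature.
object SgdScalingSketch {
  type Row = (Double, Array[Double]) // (y_3, Array(x_6, x_7, x_73_1, x_73_2, x_73_3, x_8))

  // Ten rows copied from the CSV in the question.
  val sample: Seq[Row] = Seq(
    (2995.3846153846152, Array(17.0, 1800.0, 0.0, 1.0, 0.0, 12.0)),
    (2236.304347826087, Array(17.0, 1432.0, 1.0, 0.0, 0.0, 12.0)),
    (2001.9512195121952, Array(35.0, 1432.0, 0.0, 1.0, 0.0, 5.0)),
    (992.4324324324324, Array(17.0, 1430.0, 1.0, 0.0, 0.0, 12.0)),
    (4386.666666666667, Array(26.0, 1430.0, 0.0, 0.0, 1.0, 25.0)),
    (1335.9036144578313, Array(17.0, 1432.0, 0.0, 1.0, 0.0, 5.0)),
    (1097.560975609756, Array(17.0, 1100.0, 0.0, 1.0, 0.0, 5.0)),
    (3526.6666666666665, Array(26.0, 1432.0, 0.0, 1.0, 0.0, 12.0)),
    (6105.974025974026, Array(35.0, 1800.0, 1.0, 0.0, 0.0, 25.0)),
    (7998.75, Array(35.0, 1800.0, 0.0, 1.0, 0.0, 25.0))
  )

  // Multiplicative L2 shrink applied to the weights at iteration t (t starts at 1).
  def l2Factor(stepSize: Double, regParam: Double, t: Int): Double =
    1.0 - (stepSize / math.sqrt(t)) * regParam

  def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  // Rescale every feature column to mean 0 and (population) standard deviation 1.
  def standardize(rows: Seq[Row]): Seq[Row] = {
    val d = rows.head._2.length
    val n = rows.size.toDouble
    val means = Array.tabulate(d)(j => rows.map(_._2(j)).sum / n)
    val sds = Array.tabulate(d) { j =>
      val sd = math.sqrt(rows.map(r => math.pow(r._2(j) - means(j), 2)).sum / n)
      if (sd == 0.0) 1.0 else sd // constant columns are only centered
    }
    rows.map { case (y, x) => (y, Array.tabulate(d)(j => (x(j) - means(j)) / sds(j))) }
  }

  // Returns the weights; the last entry is the intercept.
  def fit(rows: Seq[Row], stepSize: Double, regParam: Double, iterations: Int): Array[Double] = {
    val data = rows.map { case (y, x) => (y, x :+ 1.0) } // append the intercept feature
    val d = data.head._2.length
    var w = Array.fill(d)(0.0)
    for (t <- 1 to iterations) {
      // Gradient of (1 / 2n) * sum of squared residuals
      val grad = Array.fill(d)(0.0)
      for ((y, x) <- data) {
        val err = dot(w, x) - y
        for (j <- 0 until d) grad(j) += err * x(j) / data.size
      }
      val eta = stepSize / math.sqrt(t)
      val shrink = l2Factor(stepSize, regParam, t)
      w = Array.tabulate(d)(j => w(j) * shrink - eta * grad(j))
    }
    w
  }

  def rmse(rows: Seq[Row], w: Array[Double]): Double =
    math.sqrt(rows.map { case (y, x) => math.pow(y - dot(w, x :+ 1.0), 2) }.sum / rows.size)

  def main(args: Array[String]): Unit = {
    val raw = fit(sample, stepSize = 0.5, regParam = 0.0, iterations = 100)
    val std = standardize(sample)
    val scaled = fit(std, stepSize = 0.5, regParam = 0.0, iterations = 100)
    println(s"raw features:          training RMSE = ${rmse(sample, raw)}")
    println(s"standardized features: training RMSE = ${rmse(std, scaled)}")
    println(s"L2 factor at t = 1 with stepSize 0.5, regParam 10.0: ${l2Factor(0.5, 10.0, 1)}")
  }
}
```

With raw features the weights are expected to overflow, so the training RMSE comes out as NaN or an enormous value. The standardized run stays finite and ends below the all-zero-weights baseline. Within MLlib, `org.apache.spark.mllib.feature.StandardScaler` can standardize the feature vectors before `algorithm.run`, and a much smaller `regParam` (or none, to match `lm`) is worth trying.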