I am working on a Spark cluster and I want to run a linear regression by executing this code:
data = sqlContext.read.format("com.databricks.spark.csv") \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .load("/FileStore/tables/w4s3yhez1497323663423/basma.csv/")
data.cache() # Cache data for faster reuse
data.count()
from pyspark.mllib.regression import LabeledPoint
# LabeledPoint is a convenient way to get a (label, features) schema,
# but it stores the features as old-style mllib vectors
data = data.select("creat0", "gfr0m").rdd.map(lambda r: LabeledPoint(r[1], [r[0]])) \
    .toDF()
display(data)
from pyspark.ml.feature import VectorAssembler
vecAssembler = VectorAssembler(inputCols=["creat0", "gfr0m"], outputCol="features")
(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=100)
trainingData.cache()
testData.cache()
print "Training Data : ", trainingData.count()
print "Test Data : ", testData.count()
data.collect()
from pyspark.ml.regression import LinearRegression
lr = LinearRegression()
# Fit 2 models, using different regularization parameters
modelA = lr.fit(data, {lr.regParam: 0.0})
modelB = lr.fit(data, {lr.regParam: 100.0})
# Make predictions
predictionsA = modelA.transform(data)
display(predictionsA)
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(metricName="rmse")
RMSE = evaluator.evaluate(predictionsA)
print("ModelA: Root Mean Squared Error = " + str(RMSE))
# ModelA: Root Mean Squared Error = 128.602026843
predictionsB = modelB.transform(data)
RMSE = evaluator.evaluate(predictionsB)
print("ModelB: Root Mean Squared Error = " + str(RMSE))
# ModelB: Root Mean Squared Error = 129.496300193
# Import numpy, pandas, and ggplot
import numpy as np
from pandas import *
from ggplot import *
But it gives me this error:
IllegalArgumentException: u'requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually org.apache.spark.mllib.linalg.VectorUDT@f71b0bce.'
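The root cause is the LabeledPoint mapping above: pyspark.mllib.regression.LabeledPoint stores its features as old-style mllib vectors, so the DataFrame produced by .toDF() gets an mllib VectorUDT column, while pyspark.ml estimators such as LinearRegression only accept the new ml vector type. A minimal sketch of the mismatch (assuming Spark 2.x):

from pyspark.mllib.regression import LabeledPoint

lp = LabeledPoint(1.0, [2.0])
print(type(lp.features))
# <class 'pyspark.mllib.linalg.DenseVector'> -- not pyspark.ml.linalg.DenseVector,
# so the "features" column of the resulting DataFrame carries the mllib VectorUDT
# that pyspark.ml.regression.LinearRegression rejects.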
After searching for this error on Google, I found an answer: use

from pyspark.ml.linalg import Vectors, VectorUDT

instead of

from pyspark.mllib.linalg import Vectors, VectorUDT

or use

from pyspark.ml.linalg import VectorUDT
from pyspark.sql.functions import udf

together with a conversion function:
as_ml = udf(lambda v: v.asML() if v is not None else None, VectorUDT())
With some example data:
from pyspark.mllib.linalg import Vectors as MLLibVectors
df = sc.parallelize([
    (MLLibVectors.sparse(4, [0, 2], [1, -1]),),
    (MLLibVectors.dense([1, 2, 3, 4]),)
]).toDF(["features"])
result = df.withColumn("features", as_ml("features"))
But I still get the same error.
Here is a sample of the data:
cause,"weight0","dbp0","gfr0m"
1,"90","10","22.72"
5,"54","10","16.08"
6,"66","9","25.47"
3,"110","11","32.95"
5,"62","11","20.3"
5,"65","8","28.94"
1,"65","8","15.88"
5,"96","8","38.09"
5,"110","8","41.64"
4,"68","8","25.85"
5,"68","7","37.77"
1,"82","9.5","16.25"
5,"76","10","37.55"
5,"56","","37.06"
1,"93","8","18.26"
5,"80","7.5","48.49"
1,"73","8","38.37"
4,"76","8","31.09"
1,"68","8","39.62"
1,"82","8","40.08"
1,"76","9.5","28.2"
5,"81","10","36.66"
2,"80","","47.1"
5,"91","10","16.59"
2,"58","8","49.22"
1,"76","7","38.98"
,"61","8","21.8"
5,"50","6","26.97"
1,"83","7","27.81"
1,"86","8","48.62"
,"77","6","46.78"
5,"64","6","34.17"
5,"58","6","38.95"
1,"73","6","7.63"
5,"86","8","32.46"
1,"50","6","35.98"
5,"90","7","32.26"
5,"42","7","17.3"
1,"88","7","25.61"
5,"110","",""
1,"84","6","31.4"
5,"68","8","53.25"
1,"96","8","52.65"
6,"74","8","40.77"
1,"70","9.5","22.35"
6,"54","8","20.16"
1,"52","13","32.61"
,"84","8","52.98"
5,"90","9","28.67"
Answer 0 (score: 3):
Here, you just need to alias the VectorUDT imported from pyspark.ml:

from pyspark.mllib.linalg import Vectors as MLLibVectors
from pyspark.ml.linalg import VectorUDT as VectorUDTML
from pyspark.sql.functions import udf

# Convert old mllib vectors to new ml vectors, passing nulls through
as_ml = udf(lambda v: v.asML() if v is not None else None, VectorUDTML())

df = sc.parallelize([
    (MLLibVectors.sparse(4, [0, 2], [1, -1]),),
    (MLLibVectors.dense([1, 2, 3, 4]),)
]).toDF(["features"])

result = df.withColumn("features", as_ml("features"))
result.show()
# +--------------------+
# |            features|
# +--------------------+
# |(4,[0,2],[1.0,-1.0])|
# |   [1.0,2.0,3.0,4.0]|
# +--------------------+

Of course, the resulting DataFrame is not yet ready to be passed to LinearRegression, because it has no label column, but I am sure you know how to handle that.
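For completeness, here is a minimal sketch of how the whole pipeline can stay inside pyspark.ml and avoid the mllib/ml mismatch entirely, assuming the file path and the creat0/gfr0m column names from the question: skip LabeledPoint, let the already-imported VectorAssembler build the features column (note that the question's vecAssembler also lists the label gfr0m in inputCols and is never applied to the data), and rename the label column to the default name "label":

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

# Reload the raw columns and drop rows with missing values,
# since VectorAssembler cannot handle nulls
df = sqlContext.read.format("com.databricks.spark.csv") \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .load("/FileStore/tables/w4s3yhez1497323663423/basma.csv/") \
    .select("creat0", "gfr0m") \
    .dropna()

# Assemble the predictor into an ml-style vector column; the label
# only needs to be a numeric column named "label"
assembler = VectorAssembler(inputCols=["creat0"], outputCol="features")
assembled = assembler.transform(df).withColumnRenamed("gfr0m", "label")

lr = LinearRegression()    # expects "features" and "label" by default
model = lr.fit(assembled)  # no VectorUDT mismatch: features is ml.linalg

From here, the train/test split and the RegressionEvaluator from the question apply unchanged.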