我创建了一个管道(Pipeline),并尝试在 Spark 中训练 KMeans 聚类算法,但是训练失败了,我无法找出确切的错误原因。代码如下:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer, VectorAssembler, Normalizer}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession, functions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType}
// Path of the file each row was read from, split on "/" and reversed, so that
// element 1 is the file name and element 2 is the directory containing it
// (element_at uses 1-based indexing).
val reversedPathParts = reverse(split(input_file_name(), "/"))
// Read everything matched by HMP_Dataset/* as headerless, space-delimited CSV,
// tag each row with its parent directory (Class) and file name (Source), and
// give the first three positional columns the names X, Y and Z.
// Trailing dots keep the chain paste-safe in spark-shell.
val df = spark.read.
  option("header", "false").
  option("delimiter", " ").
  csv("HMP_Dataset/*").
  withColumn("Class", element_at(reversedPathParts, 2)).
  withColumn("Source", element_at(reversedPathParts, 1)).
  withColumnRenamed("_c0", "X").
  withColumnRenamed("_c1", "Y").
  withColumnRenamed("_c2", "Z")
// Columns that must be numeric before they can be assembled into a feature vector.
val coordinateCols = Seq("X", "Y", "Z")
// Cast X, Y and Z to Double and pass every other column through unchanged.
//
// A field that is missing from a line, or whose text is not a valid number,
// becomes null after the cast. VectorAssembler refuses null inputs, and because
// Spark evaluates lazily the failure only surfaces later (when KMeans.fit runs
// the first action) as
//   "Failed to execute user defined function(... struct<X,Y,Z> => vector)".
// Dropping rows with a null coordinate here keeps that UDF from ever seeing one.
val df2 = df.select(
  df.columns.map { name =>
    if (coordinateCols.contains(name)) df(name).cast(DoubleType).as(name)
    else df(name)
  }: _*
).na.drop(coordinateCols)
// Feature-engineering stages, declared in the order the pipeline runs them.

// 1. Map each distinct Class string to a numeric index.
val indexer = new StringIndexer().
  setInputCol("Class").
  setOutputCol("ClassIndex")

// 2. One-hot encode the class index into a vector column.
val encoder = new OneHotEncoderEstimator().
  setInputCols(Array("ClassIndex")).
  setOutputCols(Array("CategoryVec"))

// 3. Pack X, Y and Z into a single vector column.
val assembler = new VectorAssembler().
  setInputCols(Array("X", "Y", "Z")).
  setOutputCol("Features")

// 4. Normalise each assembled vector (Normalizer's default p-norm is used).
val normalizer = new Normalizer().
  setInputCol("Features").
  setOutputCol("feature_Norm")

val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler, normalizer))
// Fit the feature pipeline on df2 and apply it to the same data.
// NOTE: despite its name, `model` holds the transformed DataFrame, not the fitted PipelineModel.
val model = pipeline.fit(df2).transform(df2)

// Remove the raw and intermediate columns that are no longer needed.
val train = model.drop("X", "Y", "Z", "Class", "Source", "ClassIndex", "Features")

// Uncomment to inspect the intermediate DataFrames:
// model.show()
// train.show()

// Configure k-means: two clusters on the normalised features, fixed seed,
// at most 100 iterations.
val kmeansEstimator = new KMeans().
  setFeaturesCol("feature_Norm").
  setK(2).
  setSeed(1).
  setMaxIter(100)

// `kmeans` is the DataFrame returned by transform (the rows of `train` plus the
// cluster assignment), not the fitted KMeansModel.
val kmeans = kmeansEstimator.fit(train).transform(train)
train 数据框可以成功创建,但是当我把它传给 KMeans 时就会报错。错误消息是:
Failed to execute user defined function($anonfun$4: (struct<X:double,Y:double,Z:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
如何解决此问题?
答案 0 :(得分:0)
也许是库的版本或导入存在问题;在我的笔记本电脑上,这段代码运行正常。
下面是代码产生的输出,以及我的 build.sbt 和导入语句。
+--------------+-----------------------------------------------------------+----------+
|CategoryVec |feature_Norm |prediction|
+--------------+-----------------------------------------------------------+----------+
|(13,[0],[1.0])|[0.2574383611739353,0.6931032800836721,0.6733003292241385] |1 |
|(13,[0],[1.0])|[0.22614412777205142,0.6989909403863407,0.6784323833161543]|1 |
|(13,[0],[1.0])|[0.24551225268848764,0.675158694893341,0.6956180492840484] |1 |
|(13,[0],[1.0])|[0.2420417625303279,0.7059551407134563,0.6656148469584017] |1 |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1 |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1 |
|(13,[0],[1.0])|[0.2540244987629046,0.683912112053974,0.683912112053974] |1 |
|(13,[0],[1.0])|[0.2388089256503974,0.6766252893427926,0.6965260331469925] |1 |
|(13,[0],[1.0])|[0.2574383611739353,0.6733003292241385,0.6931032800836721] |1 |
|(13,[0],[1.0])|[0.2572366859677566,0.652985433610459,0.7123477457568644] |1 |
+--------------+-----------------------------------------------------------+----------+
+--------------+------------------------------------------------------------+----------+
|CategoryVec |feature_Norm |prediction|
+--------------+------------------------------------------------------------+----------+
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486] |0 |
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486] |0 |
|(13,[5],[1.0])|[0.46105396573580254,0.48899663032585117,0.7404806116362889]|0 |
|(13,[5],[1.0])|[0.4369231823814617,0.5214889596165833,0.7329034027043874] |0 |
|(13,[5],[1.0])|[0.45146611838648026,0.5078993831847903,0.7336324423780305] |0 |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371] |0 |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371] |0 |
|(13,[5],[1.0])|[0.45789190653985307,0.49951844349802155,0.7354021529276429]|0 |
|(13,[5],[1.0])|[0.4658526940598004,0.4940861906694853,0.7340709118518067] |0 |
|(13,[5],[1.0])|[0.4625915702820905,0.5046453493986442,0.7289321713535972] |0 |
+--------------+------------------------------------------------------------+----------+
build.sbt
scalaVersion := "2.11.10"
// The code uses OneHotEncoderEstimator (added in Spark 2.3.0) and
// functions.element_at (added in Spark 2.4.0), so it cannot compile against
// Spark 2.2.0; 2.4.x is the minimum usable version. Keeping the version in one
// place stops the three Spark modules from drifting apart.
val sparkVersion = "2.4.0"
// https://mvnrepository.com/artifact/org.apache.spark/spark-mllib
libraryDependencies += "org.apache.spark" %% "spark-mllib" % sparkVersion
libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion
// %% appends the Scala binary version (_2.11) automatically, matching the lines above.
libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion
导入
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{Normalizer, OneHotEncoderEstimator, StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DoubleType