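Training the KMeans algorithm on Spark fails

Time: 2020-04-16 10:07:10

Tags: scala dataframe apache-spark pipeline k-means

I created a pipeline and tried to train a KMeans clustering model in Spark, but it fails and I cannot find the exact error. Here is the code: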

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer, VectorAssembler, Normalizer}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession, functions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType}

val df = spark.read.option("header", "false").option("delimiter", " ").
  csv("HMP_Dataset/*").
  withColumn("Class"  ,  element_at(reverse(split(input_file_name,"/")),2)  ).
  withColumn("Source"  ,  element_at(reverse(split(input_file_name,"/")),1)).
  withColumnRenamed("_c0","X").withColumnRenamed("_c1","Y").
  withColumnRenamed("_c2","Z")

val df2 = df.select(
  df.columns.map {
    case x  @ "X" => df(x).cast(DoubleType).as(x)
    case y  @ "Y" => df(y).cast(DoubleType).as(y)
    case z  @ "Z" => df(z).cast(DoubleType).as(z)
    case other         => df(other)
  }: _*
)

val indexer = new StringIndexer().setInputCol("Class").setOutputCol("ClassIndex")
val encoder = new OneHotEncoderEstimator().setInputCols(Array("ClassIndex")) .setOutputCols(Array("CategoryVec"))
val assembler = new VectorAssembler().setInputCols(Array("X","Y","Z")).setOutputCol("Features")
val normalizer = new Normalizer().setInputCol("Features").setOutputCol("feature_Norm")
val pipeline = new Pipeline( ).setStages(Array ( indexer , encoder , assembler , normalizer) )
val model = pipeline.fit(df2).transform(df2)

val train = model.drop("X").drop("Y").drop("Z").drop("Class").drop("Source").drop("ClassIndex").drop("Features")
//model.show()
//train.show()
val kmeans = new KMeans().setFeaturesCol("feature_Norm").setK(2).setSeed(1).setMaxIter(100).fit(train).transform(train)
The train DataFrame is created successfully, but when I pass it to KMeans it throws an error. The error message is:

Failed to execute user defined function($anonfun$4: (struct<X:double,Y:double,Z:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).

How can I fix this?

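1 answer:

Answer 0 (score: 0)

There may be a problem with your library versions or imports; on my laptop the code works fine.

Below are the output the code produces, my build.sbt, and my imports.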

+--------------+-----------------------------------------------------------+----------+
|CategoryVec   |feature_Norm                                               |prediction|
+--------------+-----------------------------------------------------------+----------+
|(13,[0],[1.0])|[0.2574383611739353,0.6931032800836721,0.6733003292241385] |1         |
|(13,[0],[1.0])|[0.22614412777205142,0.6989909403863407,0.6784323833161543]|1         |
|(13,[0],[1.0])|[0.24551225268848764,0.675158694893341,0.6956180492840484] |1         |
|(13,[0],[1.0])|[0.2420417625303279,0.7059551407134563,0.6656148469584017] |1         |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1         |
|(13,[0],[1.0])|[0.24214029368137852,0.6860641654305725,0.6860641654305725]|1         |
|(13,[0],[1.0])|[0.2540244987629046,0.683912112053974,0.683912112053974]   |1         |
|(13,[0],[1.0])|[0.2388089256503974,0.6766252893427926,0.6965260331469925] |1         |
|(13,[0],[1.0])|[0.2574383611739353,0.6733003292241385,0.6931032800836721] |1         |
|(13,[0],[1.0])|[0.2572366859677566,0.652985433610459,0.7123477457568644]  |1         |
+--------------+-----------------------------------------------------------+----------+

+--------------+------------------------------------------------------------+----------+
|CategoryVec   |feature_Norm                                                |prediction|
+--------------+------------------------------------------------------------+----------+
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486]  |0         |
|(13,[5],[1.0])|[0.4673452175282961,0.5098311463945049,0.7222607907255486]  |0         |
|(13,[5],[1.0])|[0.46105396573580254,0.48899663032585117,0.7404806116362889]|0         |
|(13,[5],[1.0])|[0.4369231823814617,0.5214889596165833,0.7329034027043874]  |0         |
|(13,[5],[1.0])|[0.45146611838648026,0.5078993831847903,0.7336324423780305] |0         |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371]  |0         |
|(13,[5],[1.0])|[0.4561664027908625,0.5131872031397203,0.7270152044479371]  |0         |
|(13,[5],[1.0])|[0.45789190653985307,0.49951844349802155,0.7354021529276429]|0         |
|(13,[5],[1.0])|[0.4658526940598004,0.4940861906694853,0.7340709118518067]  |0         |
|(13,[5],[1.0])|[0.4625915702820905,0.5046453493986442,0.7289321713535972]  |0         |
+--------------+------------------------------------------------------------+----------+

build.sbt

scalaVersion := "2.11.10"

// https://mvnrepository.com/artifact/org.apache.spark/spark-mllib
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.2.0"
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.2.0"
libraryDependencies += "org.apache.spark" % "spark-sql_2.11" % "2.2.0"
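
If versions are the suspect, it may help to pin all Spark modules to one release. The following is only a sketch: the specific version numbers are assumptions, chosen because `OneHotEncoderEstimator` is available in Spark 2.3.0 through 2.4.x (it was renamed to `OneHotEncoder` in Spark 3.0.0), and the `sparkVersion` value is a name introduced here for illustration.

```scala
// build.sbt (sketch; the version numbers below are assumptions, not the poster's setup)
scalaVersion := "2.11.12"

// Hypothetical helper value so every Spark module uses the same release
val sparkVersion = "2.4.5"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core"  % sparkVersion,
  "org.apache.spark" %% "spark-sql"   % sparkVersion,
  "org.apache.spark" %% "spark-mllib" % sparkVersion
)
```

Using `%%` for every module lets sbt append the Scala binary suffix, so it cannot drift from `scalaVersion`.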

Imports

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{Normalizer, OneHotEncoderEstimator, StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DoubleType
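
Apart from versions, the signature in the error message (`struct<X:double,Y:double,Z:double>` to a vector struct) is that of the `VectorAssembler` step. One possible cause, which is an assumption because the full stack trace is not shown, is that some rows have a null `X`, `Y`, or `Z` after the cast to `DoubleType`, for example when a non-data file is picked up by the `HMP_Dataset/*` glob. A minimal sketch of checking for and dropping such rows, using a small hypothetical stand-in for `df2`:

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Reuses the existing session in spark-shell, or creates a local one otherwise
val spark = SparkSession.builder().master("local[*]").appName("null-check").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for df2: the second row has a null Y,
// as can happen when cast(DoubleType) meets text that is not a number
val df2 = Seq(
  (Option(1.0), Option(2.0), Option(3.0)),
  (Option(4.0), Option.empty[Double], Option(6.0))
).toDF("X", "Y", "Z")

// Count the rows in which any of the three feature columns is null
val nullRows = df2.filter(col("X").isNull || col("Y").isNull || col("Z").isNull).count()
println(s"rows with null features: $nullRows")

// Drop those rows before they reach VectorAssembler
val cleaned = df2.na.drop(Seq("X", "Y", "Z"))

val assembler = new VectorAssembler()
  .setInputCols(Array("X", "Y", "Z"))
  .setOutputCol("Features")
val assembled = assembler.transform(cleaned)
assembled.show(false)
```

In the question's code, the equivalent change would be to call `na.drop(Seq("X", "Y", "Z"))` on `df2` before `pipeline.fit`. On Spark 2.4.0 or later, `VectorAssembler` also has a `setHandleInvalid("skip")` option that should have a similar effect.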