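I'm having trouble with the two different MLlib implementations (org.apache.spark.ml and org.apache.spark.mllib) and KMeans. I'm using the newer org.apache.spark.ml implementation, which works with DataFrames, but I'm struggling with the documentation and with how to predict a cluster index.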
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}
/**
 * An example showcasing the use of KMeans.
 */
object ExploreKMeans {
  // Spark configuration.
  // Retrieve the SparkContext with spark.sparkContext.
  private val spark = SparkSession.builder()
    .appName("com.example.ml.exploration.kMeans")
    .master("local[*]")
    .getOrCreate()

  // Importing spark.implicits._ once the SparkSession exists brings in the
  // implicits (encoders, .toDF()) for converting RDDs to DataFrames.
  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val data = spark.sparkContext.parallelize(
      Array((5.0, 2.0, 1.5), (2.0, 2.5, 2.3), (1.0, 2.1, 4.2), (2.0, 5.5, 8.5)))

    // First column is the label, the remaining two form the feature vector.
    val df = data.toDF().map { row =>
      val label = row(0).asInstanceOf[Double]
      val value1 = row(1).asInstanceOf[Double]
      val value2 = row(2).asInstanceOf[Double]
      LabeledPoint(label, Vectors.dense(value1, value2))
    }

    val kmeans = new KMeans().setK(3).setSeed(1L)
    val model: KMeansModel = kmeans.fit(df)

    // Evaluate clustering by computing Within Set Sum of Squared Errors.
    val WSSSE = model.computeCost(df)
    println(s"Within Set Sum of Squared Errors = $WSSSE")

    // Shows the result.
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)

    // TODO How to predict cluster index?
    // model.predict(???
  }
}
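How do I use the model to predict the cluster index of new values? The model.predict function is not visible. This API is really confusing...
Answer 0 (score: 1)
OK, I figured it out. Prediction is now done with the transform method: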
println("Transform ")
val transformed = model.transform(df)
transformed.collect().foreach(println)
Cluster Centers:
[2.25,1.9]
[5.5,8.5]
[2.1,4.2]
Transform:
[5.0,[2.0,1.5],0]
[2.0,[2.5,2.3],0]
[1.0,[2.1,4.2],2]
[2.0,[5.5,8.5],1]
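The same transform call should also work for points the model was not trained on, which is what the question asks about. Below is a minimal sketch, assuming Spark 2.x or later with spark-mllib on the classpath. The object name PredictNewPoints, the helper predict, and the two sample points are made up for illustration; the training vectors are the feature pairs from the question.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

object PredictNewPoints {
  // Hypothetical helper: fits KMeans on the question's four feature vectors,
  // then returns the cluster index that transform assigns to each new point.
  def predict(spark: SparkSession, newPoints: Seq[Vector]): Array[Int] = {
    import spark.implicits._

    val training = Seq(
      Vectors.dense(2.0, 1.5), Vectors.dense(2.5, 2.3),
      Vectors.dense(2.1, 4.2), Vectors.dense(5.5, 8.5)
    ).map(Tuple1.apply).toDF("features")

    val model = new KMeans().setK(3).setSeed(1L).fit(training)

    // New data only needs a "features" column; no label is required.
    val unseen = newPoints.map(Tuple1.apply).toDF("features")

    // transform appends an integer "prediction" column with the cluster index.
    model.transform(unseen).select("prediction").collect().map(_.getInt(0))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("PredictNewPoints")
      .master("local[*]")
      .getOrCreate()
    val indices = predict(spark, Seq(Vectors.dense(2.2, 2.0), Vectors.dense(6.0, 9.0)))
    println(indices.mkString(", "))
    spark.stop()
  }
}
```

The column names "features" and "prediction" are the KMeans defaults; if the data uses other names, setFeaturesCol and setPredictionCol on the estimator adjust them.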
Answer 1 (score: 1)
Well, a simpler way is:
model.summary.predictions.show
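One caveat worth keeping in mind: summary.predictions only covers the rows the model was fitted on, and as far as I know the summary is only attached to a freshly fitted model, not to one loaded from disk. The sketch below guards for that with hasSummary and falls back to transform. The object name TrainingAssignments and the method name of are hypothetical, chosen just for this example.

```scala
import org.apache.spark.ml.clustering.KMeansModel
import org.apache.spark.sql.DataFrame

object TrainingAssignments {
  // Hypothetical helper: returns the training rows with their cluster index.
  // Uses the cached summary when the model still carries one, otherwise
  // recomputes the assignments with transform.
  def of(model: KMeansModel, trainingData: DataFrame): DataFrame =
    if (model.hasSummary) model.summary.predictions
    else model.transform(trainingData)
}
```

Calling TrainingAssignments.of(model, df).show() with the model and df from the question would print the same kind of rows as the transform output in the first answer. For new values, transform on a DataFrame of those values is still the way to go.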