Spark 2.0 - 如何获取与Cluster Center关联的群集ID

时间:2017-11-10 20:33:58

标签: scala apache-spark k-means

我想知道与群集中心相关的ID是什么。 model.transform(dataset)会为我的数据点分配预测的群集ID,model.clusterCenters.foreach(println)会打印这些群集中心,但我无法弄清楚如何将群集中心与其ID相关联。

import org.apache.spark.ml.clustering.KMeans

// Loads data.
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains a k-means model.
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)
val prediction = model.transform(dataset)

// Shows the result.
println("Cluster Centers: ")
model.clusterCenters.foreach(println)

理想情况下,我想要一个输出,例如:

|I.D     |cluster center |
==========================
|0       |[0.0,...,0.3]  |
|2       |[1.0,...,1.3]  |
|1       |[2.0,...,1.3]  |
|3       |[3.0,...,1.3]  |

在我看来,println订单是按ID排序的。我尝试将model.clusterCenters转换为DF到transform(),但我无法弄清楚如何将Array[org.apache.spark.ml.linalg.Vector]转换为org.apache.spark.sql.Dataset[_]

1 个答案:

答案 0 :(得分:1)

保存数据后,它将写入cluster_id和Cluster_center。你可以读取文件,可以看到所需的输出

    model.save(sc, "/user/hadoop/kmeanModel")
    val parq = sqlContext.read.parquet("/user/hadoop/kmeanModel/data/*")
    parq.collect.foreach(println)