Spark:如何获得群集点(KMeans)

时间:2016-08-22 16:31:26

标签: python apache-spark k-means

我试图在Spark中检索属于特定群集的数据点。在下面的代码中,数据已经组成,但实际上我获得了预测的聚类。

这是我到目前为止的代码:

import numpy as np
# Example data
flight_routes = np.array([[1,3,2,0],
                          [4,2,1,4],
                          [3,6,2,2],
                          [0,5,2,1]])
flight_routes = sc.parallelize(flight_routes)
model = KMeans.train(rdd=flight_routes, k=500, maxIterations=10)

route_test = np.array([[0,2,3,4]])
test = sc.parallelize(route_test)

prediction = model.predict(test)
cluster_number_predicted = prediction.collect()

print cluster_number_predicted # it returns [100] <-- COOL!!

现在,我想拥有属于群集号100的所有数据点。我如何获得这些数据? 我想要实现的就是给出这个问题的答案:Cluster points after Means (Sklearn)

提前谢谢。

1 个答案:

答案 0 :(得分:0)

如果您同时记录和预测(并且不愿意切换到Spark ML),您可以zip RDD:

predictions_and_values = model.predict(test).zip(test)

然后过滤:

predictions_and_values.filter(lambda x: x[1] == 100)