我试图在Spark中检索属于特定群集的数据点。在下面的代码中,数据已经组成,但实际上我获得了预测的聚类。
这是我到目前为止的代码:
import numpy as np
# Example data
flight_routes = np.array([[1,3,2,0],
[4,2,1,4],
[3,6,2,2],
[0,5,2,1]])
flight_routes = sc.parallelize(flight_routes)
model = KMeans.train(rdd=flight_routes, k=500, maxIterations=10)
route_test = np.array([[0,2,3,4]])
test = sc.parallelize(route_test)
prediction = model.predict(test)
cluster_number_predicted = prediction.collect()
print cluster_number_predicted # it returns [100] <-- COOL!!
现在,我想拥有属于群集号100的所有数据点。我如何获得这些数据? 我想要实现的就是给出这个问题的答案:Cluster points after Means (Sklearn)
提前谢谢。
答案 0 :(得分:0)
如果您同时记录和预测(并且不愿意切换到Spark ML),您可以zip
RDD:
predictions_and_values = model.predict(test).zip(test)
然后过滤:
predictions_and_values.filter(lambda x: x[1] == 100)