From https://spark.apache.org/docs/2.2.0/ml-clustering.html#k-means I know that after kmModel.transform(df), the output DataFrame has a prediction column indicating which cluster each record/point belongs to.
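However, I would also like to know how far each record/point deviates from its centroid, so that I can tell which points in a cluster are typical and which may lie between clusters.
How can I do this? It doesn't seem to be implemented in the package by default.
Thanks!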
Answer 0 (score: 1)
Suppose we have the following sample data and k-means model:
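from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, FloatType
from pyspark.ml.linalg import Vectors
from pyspark.ml.clustering import KMeans
import pyspark.sql.functions as F
# `spark` is predefined in the pyspark shell; create it when running as a script
spark = SparkSession.builder.getOrCreate()
data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
        (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),),
        (Vectors.dense([10.0, 1.5]),), (Vectors.dense([11.0, 0.0]),)]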
df = spark.createDataFrame(data, ["features"])
n_centres = 2
kmeans = KMeans().setK(n_centres).setSeed(1)
kmModel = kmeans.fit(df)
df_pred = kmModel.transform(df)
df_pred.show()
+----------+----------+
|  features|prediction|
+----------+----------+
| [0.0,0.0]|         1|
| [1.0,1.0]|         1|
| [9.0,8.0]|         0|
| [8.0,9.0]|         0|
|[10.0,1.5]|         0|
|[11.0,0.0]|         0|
+----------+----------+
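The prediction column is the index of the center closest to each point. As a rough plain-Python illustration of that assignment rule (not Spark's actual implementation), here is a sketch that hard-codes the two centers appearing in the joined output further below; the helper names are made up for this example:

```python
# Cluster centers copied from the joined table in this answer
# (list index = cluster id); hard-coded for illustration only.
centers = [[9.5, 4.625], [0.5, 0.5]]

points = [[0.0, 0.0], [1.0, 1.0], [9.0, 8.0],
          [8.0, 9.0], [10.0, 1.5], [11.0, 0.0]]

def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def nearest_center(point, centers):
    """Return the index of the center closest to `point`."""
    return min(range(len(centers)),
               key=lambda i: squared_distance(point, centers[i]))

predictions = [nearest_center(p, centers) for p in points]
print(predictions)  # [1, 1, 0, 0, 0, 0]
```

This reproduces the prediction column above, which is why the distance to the assigned center (computed below) is also the distance to the nearest center.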
Now let's add a column containing the coordinates of each point's cluster center:
l_clusters = kmModel.clusterCenters()
# Convert the list of centers (NumPy arrays) to a dict: cluster id -> list of floats
d_clusters = {int(i): [float(l_clusters[i][j]) for j in range(len(l_clusters[i]))]
              for i in range(len(l_clusters))}
# Create a DataFrame containing each cluster id and its center coordinates
df_centers = spark.sparkContext.parallelize([(k, v) for k, v in
                                              d_clusters.items()]).toDF(['prediction', 'center'])
df_pred = df_pred.withColumn('prediction', F.col('prediction').cast(IntegerType()))
df_pred = df_pred.join(df_centers, on='prediction', how='left')
df_pred.show()
+----------+----------+------------+
|prediction|  features|      center|
+----------+----------+------------+
|         0| [8.0,9.0]|[9.5, 4.625]|
|         0|[10.0,1.5]|[9.5, 4.625]|
|         0| [9.0,8.0]|[9.5, 4.625]|
|         0|[11.0,0.0]|[9.5, 4.625]|
|         1| [1.0,1.0]|  [0.5, 0.5]|
|         1| [0.0,0.0]|  [0.5, 0.5]|
+----------+----------+------------+
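Conceptually, the join only looks up each row's center by its cluster id. A plain-Python sketch of the same left-join semantics, using the centers and predictions shown above (the names `rows` and `with_center` are hypothetical, chosen for this illustration):

```python
# Cluster id -> center coordinates, as produced by d_clusters above.
d_clusters = {0: [9.5, 4.625], 1: [0.5, 0.5]}

# (features, prediction) pairs copied from the first output table.
rows = [([0.0, 0.0], 1), ([1.0, 1.0], 1), ([9.0, 8.0], 0),
        ([8.0, 9.0], 0), ([10.0, 1.5], 0), ([11.0, 0.0], 0)]

# A left join keeps every row; dict.get() returns None for an unknown
# cluster id, mirroring the null that a left join would produce.
with_center = [(pred, features, d_clusters.get(pred)) for features, pred in rows]

for row in with_center:
    print(row)
```

Since there are only k centers, the lookup table stays tiny no matter how many points are being scored.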
Finally, we can use a UDF to compute the squared Euclidean distance between the features column and the center coordinates:
# squared_distance returns a NumPy float; cast it to a Python float for Spark
get_dist = F.udf(lambda features, center:
                 float(features.squared_distance(center)), FloatType())
df_pred = df_pred.withColumn('dist', get_dist(F.col('features'), F.col('center')))
df_pred.show()
+----------+----------+------------+---------+
|prediction|  features|      center|     dist|
+----------+----------+------------+---------+
|         0|[11.0,0.0]|[9.5, 4.625]|23.640625|
|         0| [9.0,8.0]|[9.5, 4.625]|11.640625|
|         0| [8.0,9.0]|[9.5, 4.625]|21.390625|
|         0|[10.0,1.5]|[9.5, 4.625]|10.015625|
|         1| [1.0,1.0]|  [0.5, 0.5]|      0.5|
|         1| [0.0,0.0]|  [0.5, 0.5]|      0.5|
+----------+----------+------------+---------+
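To use these distances for spotting typical points, here is a plain-Python sketch (no Spark needed) that recomputes the dist column above, takes the point closest to each center as the most typical one, and sums the squared distances. The rows are copied from the table above; breaking ties by the first row seen is an arbitrary choice for this illustration:

```python
from collections import defaultdict

# (prediction, features, center) rows copied from the table above.
rows = [
    (0, [11.0, 0.0], [9.5, 4.625]),
    (0, [9.0, 8.0], [9.5, 4.625]),
    (0, [8.0, 9.0], [9.5, 4.625]),
    (0, [10.0, 1.5], [9.5, 4.625]),
    (1, [1.0, 1.0], [0.5, 0.5]),
    (1, [0.0, 0.0], [0.5, 0.5]),
]

def squared_distance(a, b):
    """Squared Euclidean distance, the same quantity as the dist column."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

dists = [squared_distance(f, c) for _, f, c in rows]
print(dists)  # [23.640625, 11.640625, 21.390625, 10.015625, 0.5, 0.5]

# Group (distance, features) pairs by cluster id.
by_cluster = defaultdict(list)
for (pred, features, _), d in zip(rows, dists):
    by_cluster[pred].append((d, features))

# Most typical point per cluster: smallest distance to its center.
# min() with a key returns the first minimal entry, so ties go to the earlier row.
most_typical = {k: min(v, key=lambda t: t[0])[1] for k, v in by_cluster.items()}
print(most_typical)  # {0: [10.0, 1.5], 1: [1.0, 1.0]}

# Within-cluster sum of squared distances over all points.
print(sum(dists))  # 67.6875
```

The sum should correspond to the cost that `KMeansModel.computeCost(df)` reports in Spark 2.x, which makes a quick sanity check. If you prefer distances in the original units, apply `F.sqrt` to the dist column in Spark.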