Pyspark:在数据帧的不同组上应用kmeans

时间:2017-11-10 14:01:34

标签: apache-spark group-by pyspark k-means

使用Pyspark我想在数据帧的组中分别应用kmeans,而不是同时对整个数据帧应用kmeans。目前我使用for循环迭代每个组,应用kmeans并将结果附加到另一个表。但是拥有很多团队会耗费大量时间。有人可以帮我吗? 非常感谢!

test%22.txt

2 个答案:

答案 0 :(得分:1)

我想出了使用pandas_udf的解决方案。纯火花或scala解决方案是首选但尚未提供。 假设我的数据是

import pandas as pd
df_pd = pd.DataFrame([['cat1',10.],['cat1',20.],['cat1',11.],['cat1',21.],['cat1',22.],['cat1',9.],['cat2',101.],['cat2',201.],['cat2',111.],['cat2',214.],['cat2',224.],['cat2',99.]],columns=['cat','val'])
df_sprk = spark.createDataFrame(df_pd)

首先解决熊猫问题:

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=2,random_state=0)

def skmean(kmeans,x):
    X = np.array(x)
    kmeans.fit(X)
    return(kmeans.predict(X))

您可以将skmean()应用于熊猫数据框(以确保其正常工作):

df_pd.groupby('cat').apply(lambda x:skmean(kmeans,x)).reset_index()

要将功能应用于pyspark数据框,我们使用pandas_udf。但是首先为输出数据帧定义一个模式:

from pyspark.sql.types import *
schema = StructType(
       [StructField('cat',StringType(),True),
        StructField('clusters',ArrayType(IntegerType()))])

将上面的函数转换为pandas_udf:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType  

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def skmean_udf(df):
    result = pd.DataFrame(
             df.groupby('cat').apply(lambda x: skmean(kmeans,x))
    result.reset_index(inplace=True, drop=False)
    return(result)

您可以按如下方式使用该功能:

df_spark.groupby('cat').apply(skmean_udf).show()

答案 1 :(得分:1)

我想出了第二个解决方案,我认为它比上一个更好。这个想法是将groupby()collect_list()一起使用,并编写一个以列表为输入并生成簇的udf。在另一个解决方案中,我们继续写df_spark

df_flat = df_spark.groupby('cat').agg(F.collect_list('val').alias('val_list'))

现在我们编写udf函数:

import numpy as np
import pyspark.sql.functions as F
from sklearn.cluster import KMeans
from pyspark.sql.types import *
def skmean(x):
    kmeans = KMeans(n_clusters=2, random_state=0)
    X = np.array(x).reshape(-1,1)  
    kmeans.fit(X)
    clusters = kmeans.predict(X).tolist()
    return(clusters)
clustering_udf = F.udf(lambda arr : skmean(arr), ArrayType(IntegerType()))

然后将udf应用于展平的数据框:

df = df_flat.withColumn('clusters', clustering_udf(F.col('val')))

然后,您可以使用F.explode()将列表转换为列。