PySpark:我们应该迭代更新数据帧吗?

时间:2018-04-19 09:50:49

标签: python apache-spark optimization pyspark

我的问题有两个部分。 第一个是了解Spark的工作方式,第二个是优化。

我有一个火花数据框,它有多个分类变量。对于这些分类变量中的每一个,我添加一个新列,其中每一行是相应级别的频率。

例如

Date_Built  Square_Footage  Num_Beds    Num_Baths   State   Price     Freq_State
01/01/1920  1700            3           2           NY      700000    4500

这里针对State(分类变量),我添加了一个新变量Freq_State。级别NY在数据集中显示4500次,因此此行在4500列中获得Freq_State

我有多个这样的列,我在其中添加一个相应级别的频率列。

这是我用来实现此目的的代码

def calculate_freq(df, categorical_cols):
    for each_cat_col in categorical_cols:
        _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
        df = df.join(_freq, each_cat_col, "inner")
    return df

第1部分

在这里,正如您所看到的,我正在更新for循环中的数据帧。当我在群集上运行此代码时,这种更新数据帧的方式是否可取?如果它是熊猫数据帧,我不会担心这个问题。但我不确定上下文何时变为火花。

另外,如果我只是在循环中而不是在函数内运行上述过程,它会有所作为吗?

第2部分

有更优化的方法吗?我每次进入循环时都会加入这里?可以避免吗

1 个答案:

答案 0 :(得分:1)

  

有更优化的方法吗?

有哪些可能的选择?

  1. 您可以使用窗口函数

    def calculate_freq(df, categorical_cols):
        for cat_col in categorical_cols:
            w = Window.partitionBy(cat_col)
            df = df.withColumn("{}_freq".format(each_cat_col), count("*").over(w))
        return df
    
    你应该吗?不可以。与join不同,它总是需要对非聚合的DataFrame进行完整的随机播放。

  2. 您可以melt并使用单个本地对象(这要求所有分类列属于同一类型):

    from itertools import groupby
    
    for c in categorical_cols:
         df = df.withColumn(c, df[c].cast("string"))
    
    
    rows = (melt(df, id_vars=[], value_vars=categorical_cols)
            .groupBy("variable", "value").count().collect())
    
    mapping = {k: {x.value: x["count"] for x in v} 
              for k, v in groupby(sorted(rows), lambda x: x.variable)}
    

    并使用udf添加值:

    from pyspark.sql.functions import udf
    
    def get_count(mapping_c):
        @udf("bigint")
        def _(x):
            return mapping_c.get(x)
        return _
    
    
    for c in categorical_cols:
        df = df.withColumn("{}_freq".format(c), get_count(mapping[c])(c))
    
    你应该吗?也许。与迭代连接不同,它只需要一个动作来计算所有统计信息。如果结果很小(预期有分类变量),您可以获得适度的性能提升。

  3. 添加broadcast提示。

    from pyspark.sql.functions import broadcast
    
    def calculate_freq(df, categorical_cols):
        for each_cat_col in categorical_cols:
            _freq = df.select(each_cat_col).groupBy(each_cat_col).count()
        df = df.join(broadcast(_freq), each_cat_col, "inner")
        return df
    

    Spark应该自动广播,所以它不应该改变一件事,但帮助计划者总是更好。

  4.   

    另外,如果我只是在循环中而不是在函数内运行上述过程,它会有所作为吗?

    忽略代码可维护性和可测试性。