将列添加到PySpark数据帧中会根据另外两个列的分组约束列的标准偏差

时间:2019-01-19 00:30:59

标签: dataframe pyspark standard-deviation

假设我们有一个csv文件,该文件已作为PysPark中的数据框导入,如下所示

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("file path and name.csv", inferSchema = True, header = True)
df.show()

output

+-----+----+----+
|lable|year|val |
+-----+----+----+
|    A|2003| 5.0|
|    A|2003| 6.0|
|    A|2003| 3.0|
|    A|2004|null|
|    B|2000| 2.0|
|    B|2000|null|
|    B|2009| 1.0|
|    B|2000| 6.0|
|    B|2009| 6.0|
+-----+----+----+

现在,我们要基于两列dfval的分组,向lable添加另一列,其中包含year的标准偏差。因此,输出必须如下所示:

+-----+----+----+-----+
|lable|year|val | std |
+-----+----+----+-----+
|    A|2003| 5.0| 1.53|
|    A|2003| 6.0| 1.53|
|    A|2003| 3.0| 1.53|
|    A|2004|null| null|
|    B|2000| 2.0| 2.83|
|    B|2000|null| 2.83|
|    B|2009| 1.0| 3.54|
|    B|2000| 6.0| 2.83|
|    B|2009| 6.0| 3.54|
+-----+----+----+-----+

我有以下代码适用于小型数据框,但不适用于目前正在使用的非常大的数据框(约4000万行)。

import pyspark.sql.functions as f    
a = df.groupby('lable','year').agg(f.round(f.stddev("val"),2).alias('std'))
df = df.join(a, on = ['lable', 'year'], how = 'inner')

在大型数据框上运行后,出现Py4JJavaError Traceback (most recent call last)错误。

有人知道其他方法吗?希望您的方式对我的数据集有效。

我正在使用python3.7.1pyspark2.4jupyter4.4.0

1 个答案:

答案 0 :(得分:0)

数据帧上的联接导致执行器之间的大量数据混洗。在您的情况下,您可以不使用联接。 使用窗口规范按“标签”和“年份”对数据进行分区,然后在窗口上进行汇总。

from pyspark.sql.window import *

windowSpec = Window.partitionBy('lable','year')\
                   .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df = df.withColumn("std", f.round(f.stddev("val").over(windowSpec), 2))