PySpark从collect_list数组应用UDF指数加权平均值

时间:2018-05-27 10:12:22

标签: python-3.x pyspark user-defined-functions window-functions weighted-average

我最终希望恢复与Pyspark SPARK-22239中详述的功能类似的功能,这将启用使用Pandas用户定义函数的窗口函数。

具体来说,我正在执行基础数字观测的基于时间戳的窗口化,然后计算每个窗口的指数加权平均值。我有一个工作方法,但我担心它可能效率低下,而且我可能会忽略一个更优化的解决方案。

我通过使用collect_list来获取与数据框中每一行的正确时间窗口相对应的数值数组,然后应用UDF来计算指数加权平均值,从而解决了这个问题。每个阵列。

from pyspark.sql.functions import udf, collect_list
from pyspark.sql.types import DoubleType
from pyspark.sql.window import Window


# collect the relevant set of prices for the moving average per row
def mins(t_mins):
    """
    Utility function converting time in mins to time in secs.
    """
    return 60 * t_mins
w = Window.orderBy('date').rangeBetween(-mins(30), 0)
df = df.withColumn('windowed_price', collect_list('price').over(window))

# compute the exponential weighted mean from each array of prices
@udf(DoubleType())
def arr_to_ewm(arr):
    """
    Computes exponential weighted mean per row from array of relevant time points.
    """
    series = pd.Series(arr)
    ewm = series.ewm(alpha=0.5).mean().iloc[-1]
    # make sure return type is python primitive instead of Numpy dtype
    return float(ewm)
df = df.withColumn('price_ema_30mins', arr_to_ewm(df.windowed_price))

上述方法有效,但据我所知,collect_list和udf的计算成本都很高。是否有更有效的方法在Pyspark中执行此计算?

0 个答案:

没有答案