UDF函数内部数据帧的长度

时间:2019-10-11 14:37:54

标签: scala apache-spark apache-spark-sql

我需要编写一个复杂的用户定义函数(UDF),该函数需要多列作为输入。像这样:

val uudf = udf{(val:Int, lag:Int, cumsum_p:Double) => val + lag + cum_p} // actually a more complex function but let's make it simple

第三个参数cumsum_p指示是p的累积和,其中p是计算出的 group 的长度。因为此udf随后将在groupby中使用。

我想出了这个解决方案,这几乎可以:

val uudf = udf{(val:Int, lag:Int, cumsum_p:Double) => val + lag + cum_p}
val w = Window.orderBy($"sale_qty")
df.withColumn("needThat", 
    uudf(col("sale_qty"),
       lead("sale_qty",1).over(w), sum(lit(1/length_group)).over(w)
    )
).show()

问题是,如果我将lit(1/length_group)替换为lit(1/count("sale_qty")),则创建的列现在仅包含1个元素,这会导致错误...

1 个答案:

答案 0 :(得分:1)

您应该首先计算count("sale_qty")

val w = Window.orderBy($"sale_qty")
df
.withColumn("cnt",count($"sale_qty").over()) 
.withColumn("needThat", 
    uudf(col("sale_qty"),
       lead("sale_qty",1).over(w), sum(lit(1)/$"cnt").over(w)
    )
).show()