我可以使用Spark SQL计算折扣的未来累计金额吗?下面是一个使用窗口函数计算未折现暨未来总和的示例,我用折扣总和我的意思进行了硬编码:
from pyspark.sql.window import Window
def undiscountedCummulativeFutureReward(df):
windowSpec = Window \
.partitionBy('user') \
.orderBy('time') \
.rangeBetween(0, Window.unboundedFollowing)
tot_reward = F.sum('reward').over(windowSpec)
df_tot_reward = df.withColumn('undiscounted', tot_reward)
return df_tot_reward
def makeData(spark, gamma=0.5):
data = [{'user': 'bob', 'time': 3, 'reward': 10, 'discounted_cum': 10 + (gamma * 9) + ((gamma ** 2) * 11)},
{'user': 'bob', 'time': 4, 'reward': 9, 'discounted_cum': 9 + gamma * 11},
{'user': 'bob', 'time': 5, 'reward': 11, 'discounted_cum': 11.0},
{'user': 'jo', 'time': 4, 'reward': 6, 'discounted_cum': 6 + gamma * 7},
{'user': 'jo', 'time': 5, 'reward': 7, 'discounted_cum': 7.0},
]
schema = T.StructType([T.StructField('user', T.StringType(), False),
T.StructField('time', T.IntegerType(), False),
T.StructField('reward', T.IntegerType(), False),
T.StructField('discounted_cum', T.FloatType(), False)])
return spark.createDataFrame(data=data, schema=schema)
def main(spark):
df = makeData(spark)
df = undiscountedCummulativeFutureReward(df)
df.orderBy('user', 'time').show()
return df
运行它会得到:
+----+----+------+--------------+------------+
|user|time|reward|discounted_cum|undiscounted|
+----+----+------+--------------+------------+
| bob| 3| 10| 17.25| 30|
| bob| 4| 9| 14.5| 20|
| bob| 5| 11| 11.0| 11|
| jo| 4| 6| 9.5| 13|
| jo| 5| 7| 7.0| 7|
+----+----+------+--------------+------------+
打折的是sum \gamma^k r_k for k=0 to \infinity
我想知道是否可以使用Window函数来计算折扣列,例如引入具有等级的列,带有gamma的文字,将内容相乘-但还是不太清楚-我想我可以用某种方法来实现UDF,但我认为我必须首先collect_as_list
所有用户,返回带有折扣总和的新列表,然后分解该列表。
答案 0 :(得分:0)
假设您从以下DataFrame开始:
df.show()
#+----+----+------+
#|user|time|reward|
#+----+----+------+
#| bob| 3| 10|
#| bob| 4| 9|
#| bob| 5| 11|
#| jo| 4| 6|
#| jo| 5| 7|
#+----+----+------+
您可以在user
列上将此DataFrame与自身连接,并仅保留右表的time
列大于或等于左表的time列的那些行。我们通过为数据帧l
和r
加上别名来简化此操作。
加入后,您可以从左侧表中按user
,time
和reward
进行分组,并从右侧表中汇总奖励列。但是似乎是groupBy
followed by an orderBy
is not guaranteed to maintain that order,所以您应该使用Window
来明确。
from pyspark.sql import Window, functions as f
w = Window.partitionBy("user", "l.time", "l.reward").orderBy("r.time")
df = df.alias("l").join(df.alias("r"), on="user")\
.where("r.time>=l.time")\
.select(
"user",
f.col("l.time").alias("time"),
f.col("l.reward").alias("reward"),
f.collect_list("r.reward").over(w).alias("rewards")
)
df.show()
#+----+----+------+-----------+
#|user|time|reward| rewards|
#+----+----+------+-----------+
#| jo| 4| 6| [6]|
#| jo| 4| 6| [6, 7]|
#| jo| 5| 7| [7]|
#| bob| 3| 10| [10]|
#| bob| 3| 10| [10, 9]|
#| bob| 3| 10|[10, 9, 11]|
#| bob| 4| 9| [9]|
#| bob| 4| 9| [9, 11]|
#| bob| 5| 11| [11]|
#+----+----+------+-----------+
现在,您拥有计算discounted_cum
列所需的所有元素。
您可以使用pyspark.sql.functions.posexplode
爆炸rewards
数组以及列表中的索引。这将为rewards
数组中的每个值添加一个新行。使用distinct
可以删除使用Window
函数(而不是groupBy
)引入的重复项。
我们将其称为索引k
和奖励rk
。现在,您可以使用pyspark.sql.functions.pow
gamma = 0.5
df.select("user", "time", "reward", f.posexplode("rewards").alias("k", "rk"))\
.distinct()\
.withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
.groupBy("user", "time")\
.agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
.show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#| bob| 3| 10| 17.25|
#| bob| 4| 9| 14.5|
#| bob| 5| 11| 11.0|
#| jo| 4| 6| 9.5|
#| jo| 5| 7| 7.0|
#+----+----+------+--------------+
对于旧版本的spark,在使用row_number()-1
之后,您必须使用k
来获取explode
的值:
df.select("user", "time", "reward", f.explode("rewards").alias("rk"))\
.distinct()\
.withColumn(
"k",
f.row_number().over(Window.partitionBy("user", "time").orderBy("time"))-1
)\
.withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
.groupBy("user", "time")\
.agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
.show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#| jo| 4| 6| 9.5|
#| jo| 5| 7| 7.0|
#| bob| 3| 10| 17.25|
#| bob| 4| 9| 14.5|
#| bob| 5| 11| 11.0|
#+----+----+------+--------------+