I have a DataFrame like this:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# (ID, colA, colB) rows
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()
+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+
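For each ID, I want to compute the correlation (of colB with colA) and the average of colB over the last 3 rows, i.e. a rolling window made of the current row and the two rows before it.
So for ID1, for the first element (5): average = 5, corr = 0 (a correlation over a single point is undefined, so I want 0 there).
For ID1, for the first 2 elements (5, 6): average = 5.5, corr with colA = 1.
For ID1, for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1.
For ID1, for the elements (6, 7, 4), i.e. colA = 2, 3, 4: average = 5.66, corr with colA = -0.65.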
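As a check on the last line, the -0.65 is the Pearson correlation of colA = (2, 3, 4) with colB = (6, 7, 4), whose means are 3 and 17/3; the same mean gives the average:

```latex
\begin{aligned}
\sum_i (x_i - \bar{x})(y_i - \bar{y}) &= (-1)\left(\tfrac{1}{3}\right) + 0\left(\tfrac{4}{3}\right) + (1)\left(-\tfrac{5}{3}\right) = -2 \\
\sum_i (x_i - \bar{x})^2 &= 1 + 0 + 1 = 2, \qquad \sum_i (y_i - \bar{y})^2 = \tfrac{1}{9} + \tfrac{16}{9} + \tfrac{25}{9} = \tfrac{14}{3} \\
r &= \frac{-2}{\sqrt{2 \cdot 14/3}} = \frac{-2}{\sqrt{28/3}} \approx -0.6547, \qquad \bar{y} = \tfrac{17}{3} \approx 5.667
\end{aligned}
```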
The expected output looks like this:
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|         0|        5|
|ID1|   2|   6|         1|      5.5|
|ID1|   3|   7|         1|        6|
|ID1|   4|   4|     -0.65|     5.66|
|ID1|   5|   2|     -0.99|     4.33|
|ID1|   6|   2|     -0.86|     2.66|
|ID2|   1|   4|         0|        4|
|ID2|   2|   6|         1|        5|
|ID2|   3|   1|     -0.59|     3.66|
|ID2|   4|   1|     -0.86|     2.66|
|ID2|   5|   4|      0.86|        2|
+---+----+----+----------+---------+
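To check these numbers without Spark, here is a minimal pure-Python sketch (not part of the original question) that walks each ID in colA order and computes the Pearson correlation and the mean of colB over the current row and up to two preceding rows. The helpers `pearson` and `rolling_last3` are hypothetical names, the sketch assumes colA values are distinct within an ID, and, as in the expected output, a one-row window gets a correlation of 0. Values are rounded to 3 decimals rather than truncated to 2.

```python
from math import sqrt

# Sample data from the question: (ID, colA, colB)
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (NaN if either is constant)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return float("nan")
    return sxy / sqrt(sxx * syy)


def rolling_last3(rows):
    """Return (ID, colA, colB, corr_last3, avg_last3) per row; both statistics use
    the current row and up to two preceding rows of the same ID, ordered by colA."""
    groups = {}
    for row in sorted(rows):  # sorts by ID, then colA
        groups.setdefault(row[0], []).append(row)
    out = []
    for group in groups.values():
        for i, (id_, a, b) in enumerate(group):
            window = group[max(0, i - 2): i + 1]  # 2 preceding rows + current row
            xs = [r[1] for r in window]
            ys = [r[2] for r in window]
            # A single point has no defined correlation; use 0.0 as in the expected output
            c = pearson(xs, ys) if len(window) > 1 else 0.0
            out.append((id_, a, b, round(c, 3), round(sum(ys) / len(ys), 3)))
    return out


for row in rolling_last3(data):
    print(row)
```

For ID1 this prints corr_last3 values 0.0, 1.0, 1.0, -0.655, -0.993, -0.866 and avg_last3 values 5.0, 5.5, 6.0, 5.667, 4.333, 2.667, which agree with the expected table up to its two-decimal truncation.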
Answer 0 (score: 3)
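You can do this with the built-in functions avg and corr applied over a window frame of the current row and the two preceding rows. The row index from row_number is correlated with colB, which here is the same as using colA because colA runs 1, 2, 3, ... within each ID. Here is a Scala solution: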
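import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // enables the $"col" syntax
// Current row plus the two rows before it, within each ID
val last3 = Window.partitionBy($"ID").orderBy($"colA").rowsBetween(-2L, Window.currentRow)
df
  .withColumn("indices", row_number().over(Window.partitionBy($"ID").orderBy($"colA")))
  // corr over a single row is undefined, so return 0.0 for the first row of each ID
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(last3)).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(last3))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()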
which gives:
+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+
Answer 1 (score: 0)
Here is the same approach in PySpark:
from pyspark.sql import Window
from pyspark.sql.functions import row_number, corr, when, mean, col, round
# Current row plus the two rows before it, within each ID. rowsBetween counts rows;
# rangeBetween would instead select rows by their colA values.
last3 = Window.partitionBy("ID").orderBy("colA").rowsBetween(-2, Window.currentRow)
df = df\
    .withColumn("indices", row_number().over(Window.partitionBy("ID").orderBy("colA")))\
    .withColumn("corr_last3", when(col("indices") > 1, corr(col("indices"), col("colB")).over(last3)).otherwise(0.0))\
    .withColumn("avg_last3", mean(col("colB")).over(last3))\
    .drop(col("indices"))\
    .orderBy("ID", "colA")
# Round for display; each column is rounded from itself (avg_last3 from avg_last3)
df = df.withColumn("corr_last3", round(col("corr_last3"), 3))\
    .withColumn("avg_last3", round(col("avg_last3"), 3))
df.show()
+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|       0.0|      5.0|
|ID1|   2|   6|       1.0|      5.5|
|ID1|   3|   7|       1.0|      6.0|
|ID1|   4|   4|    -0.655|    5.667|
|ID1|   5|   2|    -0.993|    4.333|
|ID1|   6|   2|    -0.866|    2.667|
|ID2|   1|   4|       0.0|      4.0|
|ID2|   2|   6|       1.0|      5.0|
|ID2|   3|   1|    -0.596|    3.667|
|ID2|   4|   1|    -0.866|    2.667|
|ID2|   5|   4|     0.866|      2.0|
+---+----+----+----------+---------+
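rangeBetween(-2, Window.currentRow) and rowsBetween(-2, Window.currentRow) produce the same windows on this data only because colA runs 1, 2, 3, ... without gaps inside each ID. A row frame counts physical rows, while a range frame ordered by colA keeps the rows whose colA lies within 2 of the current value, so a gap in colA shrinks the window. The pure-Python sketch below mimics both frame types on a hypothetical group with no colA == 3; it does not use Spark, and `rows_frame` and `range_frame` are illustrative names only.

```python
# Hypothetical group (not from the question): colA skips 3; values are (colA, colB)
rows = [(1, 5), (2, 6), (4, 7), (5, 4)]  # already ordered by colA


def rows_frame(rows, i):
    """Mimics rowsBetween(-2, currentRow): the current row and the two rows before it."""
    return rows[max(0, i - 2): i + 1]


def range_frame(rows, i):
    """Mimics rangeBetween(-2, currentRow) ordered by colA: every row whose colA
    lies in [colA_i - 2, colA_i]."""
    a = rows[i][0]
    return [r for r in rows if a - 2 <= r[0] <= a]


for i, (a, _) in enumerate(rows):
    # colA, colB values in the row frame, colB values in the range frame
    print(a, [b for _, b in rows_frame(rows, i)], [b for _, b in range_frame(rows, i)])
# 1 [5] [5]
# 2 [5, 6] [5, 6]
# 4 [5, 6, 7] [6, 7]
# 5 [6, 7, 4] [7, 4]
```

Since the question asks for the last 3 rows regardless of the colA values, the row-based frame is the one that matches it in general.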