现在我有这样的数据:
+----+----+
|col1| d|
+----+----+
| A| 4|
| A| 10|
| A| 3|
| B| 3|
| B| 6|
| B| 4|
| B| 5.5|
| B| 13|
+----+----+
col1是StringType,d是TimestampType,这里我改用DoubleType。 我想基于条件元组生成数据。 给定一个元组[(A,3.5),(A,8),(B,3.5),(B,10)] 我想要像这样的结果
+----+---+
|col1| d|
+----+---+
| A| 4|
| A| 10|
| B| 4|
| B| 13|
+----+---+
也就是说,对于元组中的每个元素,我们从pyspark数据框中选择d大于元组编号且col1等于元组字符串的前1行。 我已经写的是:
df_res=spark_empty_dataframe
for (x,y) in tuples:
dft=df.filter(df.col1==x).filter(df.d>y).limit(1)
df_res=df_res.union(dft)
但是我认为这可能存在效率问题,我不知道我是否正确。
答案 0 :(得分:2)
一种避免循环的可能方法是根据输入的元组创建一个数据框:
t = [('A',3.5),('A',8),('B',3.5),('B',10)]
ref=spark.createDataFrame([(i[0],float(i[1])) for i in t],("col1_y","d_y"))
然后我们可以在条件上加入输入数据帧(df
)上,然后对元组的键和值进行分组,将其重复以获取每个组的第一个值,然后删除多余的列: / p>
(df.join(ref,(df.col1==ref.col1_y)&(df.d>ref.d_y),how='inner').orderBy("col1","d")
.groupBy("col1_y","d_y").agg(F.first("col1").alias("col1"),F.first("d").alias("d"))
.drop("col1_y","d_y")).show()
+----+----+
|col1| d|
+----+----+
| A|10.0|
| A| 4.0|
| B| 4.0|
| B|13.0|
+----+----+
请注意,如果数据框的顺序很重要,则可以使用monotonically_increasing_id
分配索引列,并将其包括在聚合中,然后按索引列进行排序。
用另一种方法代替订购,直接用first
获取min
:
(df.join(ref,(df.col1==ref.col1_y)&(df.d>ref.d_y),how='inner')
.groupBy("col1_y","d_y").agg(F.min("col1").alias("col1"),F.min("d").alias("d"))
.drop("col1_y","d_y")).show()
+----+----+
|col1| d|
+----+----+
| B| 4.0|
| B|13.0|
| A| 4.0|
| A|10.0|
+----+----+