Merge rows in PySpark if a condition is met

Asked: 2019-02-21 16:13:01

Tags: python pyspark

I have a PySpark DataFrame like the following:

shouldMerge | number
true        | 1
true        | 1
true        | 2
false       | 3
false       | 1 

I want to merge all the rows where shouldMerge is true and add up their numbers.

So the final output would look like:

shouldMerge | number
true        | 4
false       | 3
false       | 1

How do I select all the rows where shouldMerge == true, sum their numbers, and generate a new row in PySpark?

Edit: here is another, slightly more complicated scenario that is closer to what I'm trying to solve, where we only aggregate the rows with positive mergeIds:

mergeId     | number
1           | 1
2           | 1
1           | 2
-1          | 3
-1          | 1 

mergeId     | number
1           | 3
2           | 1
-1          | 3
-1          | 1

3 Answers:

Answer 0 (score: 1)

IIUC, you only want to do the groupBy for the positive mergeIds.

One approach is to filter the DataFrame for the positive ids, group them, aggregate, and then union the result back with the negative ids (similar to @shanmuga's answer below).

Another approach is to dynamically create the grouping key using when: if mergeId is positive, use mergeId for the grouping. Otherwise, use a monotonically_increasing_id to ensure the row does not get aggregated with anything else.

Here is an example:

import pyspark.sql.functions as f

# Group by mergeId when it is positive; otherwise group by a per-row
# unique id so that rows with non-positive mergeIds are left untouched.
df.withColumn("uid", f.monotonically_increasing_id())\
    .groupBy(
        f.when(
            f.col("mergeId") > 0, 
            f.col("mergeId")
        ).otherwise(f.col("uid")).alias("mergeKey"), 
        f.col("mergeId")
    )\
    .agg(f.sum("number").alias("number"))\
    .drop("mergeKey")\
    .show()
#+-------+------+
#|mergeId|number|
#+-------+------+
#|     -1|   1.0|
#|      1|   3.0|
#|      2|   1.0|
#|     -1|   3.0|
#+-------+------+

This can easily be generalized to your specific requirements by changing the when condition (in this case, f.col("mergeId") > 0).
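For instance, a hypothetical variant that only merges rows whose mergeId is in an allow-list (the list mergeable_ids below is made up purely for illustration) would only swap the condition inside when:

import pyspark.sql.functions as f

mergeable_ids = [1, 2]  # hypothetical allow-list, not from the original answer

df.withColumn("uid", f.monotonically_increasing_id())\
    .groupBy(
        f.when(
            f.col("mergeId").isin(mergeable_ids),  # the only line that changes
            f.col("mergeId")
        ).otherwise(f.col("uid")).alias("mergeKey"),
        f.col("mergeId")
    )\
    .agg(f.sum("number").alias("number"))\
    .drop("mergeKey")\
    .show()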


Explanation

First, we create a temporary column uid, which is a unique ID for each row. Next, we call groupBy: if mergeId is positive, we group by mergeId; otherwise, we use uid as the mergeKey. I also pass mergeId in as a second grouping column, which is a way to keep that column in the output.

To demonstrate what is going on, take a look at this intermediate result:

# Materialize the grouping key as its own column to inspect it
df.withColumn("uid", f.monotonically_increasing_id())\
    .withColumn(
        "mergeKey",
        f.when(
            f.col("mergeId") > 0, 
            f.col("mergeId")
        ).otherwise(f.col("uid"))
    )\
    .show()
#+-------+------+-----------+-----------+
#|mergeId|number|        uid|   mergeKey|
#+-------+------+-----------+-----------+
#|      1|     1|          0|          1|
#|      2|     1| 8589934592|          2|
#|      1|     2|17179869184|          1|
#|     -1|     3|25769803776|25769803776|
#|     -1|     1|25769803777|25769803777|
#+-------+------+-----------+-----------+

As you can see, mergeKey is still a unique value for the rows with a negative mergeId.

From this intermediate step, getting the desired result is trivial: group, sum the number column, and drop the mergeKey column.

Answer 1 (score: 0)

You just need to filter the rows where shouldMerge is true and aggregate those, then union the result with all the remaining rows.

import pyspark.sql.functions as functions

df = sqlContext.createDataFrame([
    (True, 1),
    (True, 1),
    (True, 2),
    (False, 3),
    (False, 1),
], ("shouldMerge", "number"))

# Split the rows to be merged from the rows to be left as-is
false_df = df.filter("shouldMerge = false")
true_df = df.filter("shouldMerge = true")

# Aggregate the mergeable rows, then add the untouched rows back
result = true_df.groupBy("shouldMerge")\
    .agg(functions.sum("number").alias("number"))\
    .unionAll(false_df)
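
For reference, showing result should reproduce the expected output from the question (a sketch of the expected rows; the actual row order may differ):

result.show()
# +-----------+------+
# |shouldMerge|number|
# +-----------+------+
# |       true|     4|
# |      false|     3|
# |      false|     1|
# +-----------+------+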



The same approach works for the second (mergeId) case; only the filter condition and the grouping column change:
df = sqlContext.createDataFrame([
    (1, 1),
    (2, 1),
    (1, 2),
    (-1, 3),
    (-1, 1),
], ("mergeId", "number"))

# Only rows with a positive mergeId should be merged
merge_condition = df["mergeId"] > 0
remaining = ~merge_condition
groupby_field = "mergeId"

false_df = df.filter(remaining)
true_df = df.filter(merge_condition)
result = true_df.groupBy(groupby_field)\
    .agg(functions.sum("number").alias("number"))\
    .unionAll(false_df)

result.show()
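
Based on the expected output in the edited question, this should show the following rows (again, the actual order may differ):

# +-------+------+
# |mergeId|number|
# +-------+------+
# |      1|     3|
# |      2|     1|
# |     -1|     3|
# |     -1|     1|
# +-------+------+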

Answer 2 (score: -1)

For the first question posted by the OP:

# Create the DataFrame
valuesCol = [(True,1),(True,1),(True,2),(False,3),(False,1)]
df = sqlContext.createDataFrame(valuesCol,['shouldMerge','number'])
df.show()
+-----------+------+
|shouldMerge|number|
+-----------+------+
|       true|     1|
|       true|     1|
|       true|     2|
|      false|     3|
|      false|     1|
+-----------+------+

# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Register the dataframe as a view
df.registerTempTable('table_view')
df=sqlContext.sql(
    'select shouldMerge, number, sum(number) over (partition by shouldMerge) as sum_number from table_view'
)
df = df.withColumn('number',when(col('shouldMerge')==True,col('sum_number')).otherwise(col('number')))
df.show()
+-----------+------+----------+
|shouldMerge|number|sum_number|
+-----------+------+----------+
|       true|     4|         4|
|       true|     4|         4|
|       true|     4|         4|
|      false|     3|         4|
|      false|     1|         4|
+-----------+------+----------+

df = df.drop('sum_number')
# No partition column here: ordering spans all rows, which pulls everything
# into a single partition (fine for small data)
my_window = Window.partitionBy().orderBy('shouldMerge')
df = df.withColumn('shouldMerge_lag', lag(col('shouldMerge'), 1).over(my_window))
df.show()
+-----------+------+---------------+
|shouldMerge|number|shouldMerge_lag|
+-----------+------+---------------+
|      false|     3|           null|
|      false|     1|          false|
|       true|     4|          false|
|       true|     4|           true|
|       true|     4|           true|
+-----------+------+---------------+

df = df.where(~((col('shouldMerge')==True) & (col('shouldMerge_lag')==True))).drop('shouldMerge_lag')
df.show()
+-----------+------+
|shouldMerge|number|
+-----------+------+
|      false|     3|
|      false|     1|
|       true|     4|
+-----------+------+
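
As an aside, once all the merged true rows are identical, a simpler way to deduplicate them is dropDuplicates. This sketch is not part of the original answer, and it assumes the unmerged rows are themselves distinct:

# All (true, 4) rows are identical at this point, so keeping one row per
# distinct (shouldMerge, number) pair collapses them while leaving the
# distinct false rows untouched.
df.dropDuplicates(['shouldMerge', 'number']).show()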

For the second question posted by the OP:

# Create the DataFrame
valuesCol = [(1,2),(1,1),(2,1),(1,2),(-1,3),(-1,1)]
df = sqlContext.createDataFrame(valuesCol,['mergeId','number'])
df.show()
+-------+------+
|mergeId|number|
+-------+------+
|      1|     2|
|      1|     1|
|      2|     1|
|      1|     2|
|     -1|     3|
|     -1|     1|
+-------+------+

# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Register the dataframe as a view
df.registerTempTable('table_view')
df=sqlContext.sql(
    'select mergeId, number, sum(number) over (partition by mergeId) as sum_number from table_view'
)
df = df.withColumn('number',when(col('mergeId') > 0,col('sum_number')).otherwise(col('number')))
df.show()
+-------+------+----------+
|mergeId|number|sum_number|
+-------+------+----------+
|      1|     5|         5|
|      1|     5|         5|
|      1|     5|         5|
|      2|     1|         1|
|     -1|     3|         4|
|     -1|     1|         4|
+-------+------+----------+

df = df.drop('sum_number')
my_window = Window.partitionBy('mergeId').orderBy('mergeId')
df = df.withColumn('mergeId_lag', lag(col('mergeId'),1).over(my_window))
df.show()
+-------+------+-----------+
|mergeId|number|mergeId_lag|
+-------+------+-----------+
|      1|     5|       null|
|      1|     5|          1|
|      1|     5|          1|
|      2|     1|       null|
|     -1|     3|       null|
|     -1|     1|         -1|
+-------+------+-----------+

df = df.where(~((col('mergeId') > 0) & (col('mergeId_lag').isNotNull()))).drop('mergeId_lag')
df.show()
+-------+------+
|mergeId|number|
+-------+------+
|      1|     5|
|      2|     1|
|     -1|     3|
|     -1|     1|
+-------+------+

Documentation: lag() - returns the value that is offset rows before the current row.
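
A minimal standalone illustration of lag (the day/amount columns are made up for this demo only):

from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col

sales = sqlContext.createDataFrame([(1, 10), (2, 15), (3, 12)], ('day', 'amount'))
w = Window.partitionBy().orderBy('day')
sales.withColumn('prev_amount', lag(col('amount'), 1).over(w)).show()
# day 1 has no previous row, so prev_amount is null; day 2 sees 10; day 3 sees 15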