PySpark: emulating SQL's UPDATE

Date: 2019-08-27 17:33:58

Tags: pyspark pyspark-sql

I have two Spark DataFrames:

trg
+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1|  0.1|unchanged|
|  2|  0.2|unchanged|
|  3|  0.3|unchanged|
+---+-----+---------+
src
+---+-----+-------+-----+
|key|value|   flag|merge|
+---+-----+-------+-----+
|  1| 0.11|changed|    0|
|  2| 0.22|changed|    1|
|  3| 0.33|changed|    0|
+---+-----+-------+-----+

I need to "update" trg.value and trg.flag based on src.merge, as described by the following SQL logic:

UPDATE trg
    INNER JOIN src ON trg.key = src.key
  SET trg.value = src.value,
      trg.flag = src.flag
  WHERE src.merge = 1;

The expected new trg:

+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1| 0.1 |unchanged|
|  2| 0.22|  changed|
|  3| 0.3 |unchanged|
+---+-----+---------+

I tried using when(). It works for the flag field (since flag can only take two values), but not for the value field, because I don't know how to pull the value from the corresponding row:

from pyspark.sql.functions import when

trg = spark.createDataFrame(data=[('1', '0.1', 'unchanged'),
                                  ('2', '0.2', 'unchanged'),
                                  ('3', '0.3', 'unchanged')],
                            schema=['key', 'value', 'flag'])
src = spark.createDataFrame(data=[('1', '0.11', 'changed', '0'),
                                  ('2', '0.22', 'changed', '1'),
                                  ('3', '0.33', 'changed', '0')],
                            schema=['key', 'value', 'flag', 'merge'])

new_trg = (trg.alias('trg').join(src.alias('src'), on=['key'], how='inner')
          .select(
              'trg.*',
              when(src.merge == 1, 'changed').otherwise('unchanged').alias('flag'),
              when(src.merge == 1, ???).otherwise(???).alias('value')))

Is there another (preferably idiomatic) way to translate this SQL logic into PySpark?

3 answers:

Answer 0 (score: 2)

Imports and dataset creation

import pyspark.sql.functions as f

l1 = [(1, 0.1, 'unchanged'), (2, 0.2, 'unchanged'), (3, 0.3, 'unchanged')]
dfl1 = spark.createDataFrame(l1).toDF('key', 'value', 'flag')
dfl1.show()
+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1|  0.1|unchanged|
|  2|  0.2|unchanged|
|  3|  0.3|unchanged|
+---+-----+---------+


l2 = [(1, 0.11, 'changed', 0), (2, 0.22, 'changed', 1), (3, 0.33, 'changed', 0)]
dfl2 = spark.createDataFrame(l2).toDF('key', 'value', 'flag', 'merge')
dfl2.show()
+---+-----+-------+-----+
|key|value|   flag|merge|
+---+-----+-------+-----+
|  1| 0.11|changed|    0|
|  2| 0.22|changed|    1|
|  3| 0.33|changed|    0|
+---+-----+-------+-----+
# filtering upfront for better performance in next join
# dfl2 = dfl2.where(dfl2['merge'] == 1) 

Join the datasets

join_cond = [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1]
dfl12 = dfl1.join(dfl2, join_cond, 'left_outer')
dfl12.show()
+---+-----+---------+----+-----+-------+-----+
|key|value|     flag| key|value|   flag|merge|
+---+-----+---------+----+-----+-------+-----+
|  1|  0.1|unchanged|null| null|   null| null|
|  3|  0.3|unchanged|null| null|   null| null|
|  2|  0.2|unchanged|   2| 0.22|changed|    1|
+---+-----+---------+----+-----+-------+-----+

Use when(): where the src columns came back null from the left join (no matching merge row), keep the original values; otherwise take the new ones.

# take the src value where the join matched (merge == 1), otherwise keep the trg value
df = (dfl12
      .withColumn('new_value', f.when(dfl2['value'].isNotNull(), dfl2['value']).otherwise(dfl1['value']))
      .withColumn('new_flag', f.when(dfl2['flag'].isNotNull(), dfl2['flag']).otherwise(dfl1['flag'])))

df.show()
+---+-----+---------+----+-----+-------+-----+---------+---------+
|key|value|     flag| key|value|   flag|merge|new_value| new_flag|
+---+-----+---------+----+-----+-------+-----+---------+---------+
|  1|  0.1|unchanged|null| null|   null| null|      0.1|unchanged|
|  3|  0.3|unchanged|null| null|   null| null|      0.3|unchanged|
|  2|  0.2|unchanged|   2| 0.22|changed|    1|     0.22|  changed|
+---+-----+---------+----+-----+-------+-----+---------+---------+

df.select(dfl1['key'], df['new_value'], df['new_flag']).show()
+---+---------+---------+
|key|new_value| new_flag|
+---+---------+---------+
|  1|      0.1|unchanged|
|  3|      0.3|unchanged|
|  2|     0.22|  changed|
+---+---------+---------+

This is spelled out step by step for clarity; the steps can be combined into a single expression, as sketched below.
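For instance, a one-expression version of the same approach (a sketch, reusing the dfl1, dfl2, and f names defined above):

# left join restricted to merge == 1, then fall back to the original columns when no row matched
new_df = (dfl1.join(dfl2, [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1], 'left_outer')
          .select(dfl1['key'],
                  f.when(dfl2['value'].isNotNull(), dfl2['value']).otherwise(dfl1['value']).alias('value'),
                  f.when(dfl2['flag'].isNotNull(), dfl2['flag']).otherwise(dfl1['flag']).alias('flag')))
new_df.show()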

Answer 1 (score: 2)

# join on key, then pick each column from src when merge == 1, else keep trg's original
newdf = (trg.join(src, on=['key'], how='inner')
         .select(trg.key,
                 when(src.merge == 1, src.value)
                 .otherwise(trg.value).alias('value'),
                 when(src.merge == 1, src.flag)
                 .otherwise(trg.flag).alias('flag')))
newdf.show()

+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1|  0.1|unchanged|
|  2| 0.22|  changed|
|  3|  0.3|unchanged|
+---+-----+---------+
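An equivalent variant, in case the when/otherwise pairs feel repetitive: restrict src to its merge rows first, left-join, and let f.coalesce fall back to trg's columns. A sketch, assuming the trg and src DataFrames from the question (coalesce returns its first non-null argument):

import pyspark.sql.functions as f

# keep only the rows that should overwrite trg, renamed to avoid column clashes
src_upd = src.where(src.merge == 1).select('key',
                                           src.value.alias('value_u'),
                                           src.flag.alias('flag_u'))
newdf = (trg.join(src_upd, on='key', how='left')
         .select('key',
                 f.coalesce('value_u', 'value').alias('value'),
                 f.coalesce('flag_u', 'flag').alias('flag')))
newdf.show()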

Answer 2 (score: 0)

import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import when

spark = SparkSession.builder.appName("test").getOrCreate()

data1 = [(1, 0.1, 'unchanged'), (2, 0.2,'unchanged'), (3, 0.3, 'unchanged')]
schema = ['key', 'value', 'flag']
df1 = spark.createDataFrame(data1, schema=schema)
df1.show()

+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1|  0.1|unchanged|
|  2|  0.2|unchanged|
|  3|  0.3|unchanged|
+---+-----+---------+


data2 = [(1, 0.11, 'changed',0), (2, 0.22,'changed',1), (3, 0.33, 'changed',0)]
schema2 = ['key', 'value', 'flag', 'merge']
df2 = spark.createDataFrame(data2, schema=schema2)
df2.show()

+---+-----+-------+-----+
|key|value|   flag|merge|
+---+-----+-------+-----+
|  1| 0.11|changed|    0|
|  2| 0.22|changed|    1|
|  3| 0.33|changed|    0|
+---+-----+-------+-----+

# rename src's columns so they don't clash with df1's after the join
df2 = df2.withColumnRenamed("value", "value1").withColumnRenamed("flag", 'flag1')
mer = df1.join(df2, ['key'], 'inner')
# take the new value/flag only on rows where merge == 1
mer = mer.withColumn("temp", when(mer.merge == 1, mer.value1).otherwise(mer.value))
mer = mer.withColumn("temp1", when(mer.merge == 1, 'changed').otherwise('unchanged'))
output = mer.select(mer.key, mer.temp.alias('value'), mer.temp1.alias('flag'))
output.orderBy(output.value.asc()).show()

+---+-----+---------+
|key|value|     flag|
+---+-----+---------+
|  1|  0.1|unchanged|
|  2| 0.22|  changed|
|  3|  0.3|unchanged|
+---+-----+---------+
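For completeness, the original SQL logic can also be run almost verbatim through temp views with spark.sql; a minimal sketch, assuming the trg and src DataFrames from the question:

trg.createOrReplaceTempView('trg')
src.createOrReplaceTempView('src')

# CASE WHEN plays the role of when/otherwise from the DataFrame API
new_trg = spark.sql("""
    SELECT t.key,
           CASE WHEN s.merge = 1 THEN s.value ELSE t.value END AS value,
           CASE WHEN s.merge = 1 THEN s.flag  ELSE t.flag  END AS flag
    FROM trg t
    JOIN src s ON t.key = s.key
""")
new_trg.show()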