I have two Spark DataFrames:
trg:
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
src:
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
I need to "update" trg.value and trg.flag based on src.merge, as described by the following SQL logic:
UPDATE trg
INNER JOIN src ON trg.key = src.key
SET trg.value = src.value,
trg.flag = src.flag
WHERE src.merge = 1;
Expected new trg:
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+
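I tried using when(). It works for the flag field (since it can only take two values), but not for the value field, because I don't know how to pick the value from the corresponding row: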
from pyspark.sql.functions import when
trg = spark.createDataFrame(data=[('1', '0.1', 'unchanged'),
('2', '0.2', 'unchanged'),
('3', '0.3', 'unchanged')],
schema=['key', 'value', 'flag'])
src = spark.createDataFrame(data=[('1', '0.11', 'changed', '0'),
('2', '0.22', 'changed', '1'),
('3', '0.33', 'changed', '0')],
schema=['key', 'value', 'flag', 'merge'])
new_trg = (trg.alias('trg').join(src.alias('src'), on=['key'], how='inner')
.select(
'trg.*',
when(src.merge == 1, 'changed').otherwise('unchanged').alias('flag'),
when(src.merge == 1, ???).otherwise(???).alias('value')))
Is there another (preferably idiomatic) way to translate this SQL logic into PySpark?
Answer 0 (score: 2)
Import the functions and create the datasets
import pyspark.sql.functions as f
l1 = [(1, 0.1, 'unchanged'), (2, 0.2, 'unchanged'), (3, 0.3, 'unchanged')]
dfl1 = spark.createDataFrame(l1).toDF('key', 'value', 'flag')
dfl1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
l2 = [(1, 0.11, 'changed', 0), (2, 0.22, 'changed', 1), (3, 0.33, 'changed', 0)]
dfl2 = spark.createDataFrame(l2).toDF('key', 'value', 'flag', 'merge')
dfl2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
# filtering upfront for better performance in next join
# dfl2 = dfl2.where(dfl2['merge'] == 1)
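Join the datasets. The merge == 1 check is part of the join condition, so trg rows that should not be updated get nulls on the src side: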
join_cond = [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1]
dfl12 = dfl1.join(dfl2, join_cond, 'left_outer')
dfl12.show()
+---+-----+---------+----+-----+-------+-----+
|key|value| flag| key|value| flag|merge|
+---+-----+---------+----+-----+-------+-----+
| 1| 0.1|unchanged|null| null| null| null|
| 3| 0.3|unchanged|null| null| null| null|
| 2| 0.2|unchanged| 2| 0.22|changed| 1|
+---+-----+---------+----+-----+-------+-----+
Use when(): if the src column is null, keep the original value; otherwise take the new one.
df = dfl12.withColumn('new_value', f.when(dfl2['value'].isNotNull(), dfl2['value']).otherwise(dfl1['value'])).\
withColumn('new_flag', f.when(dfl2['flag'].isNotNull(), dfl2['flag']).otherwise(dfl1['flag']))
df.show()
+---+-----+---------+----+-----+-------+-----+---------+---------+
|key|value| flag| key|value| flag|merge|new_value| new_flag|
+---+-----+---------+----+-----+-------+-----+---------+---------+
| 1| 0.1|unchanged|null| null| null| null| 0.1|unchanged|
| 3| 0.3|unchanged|null| null| null| null| 0.3|unchanged|
| 2| 0.2|unchanged| 2| 0.22|changed| 1| 0.22| changed|
+---+-----+---------+----+-----+-------+-----+---------+---------+
df.select(dfl1['key'], df['new_value'], df['new_flag']).show()
+---+---------+---------+
|key|new_value| new_flag|
+---+---------+---------+
| 1| 0.1|unchanged|
| 3| 0.3|unchanged|
| 2| 0.22| changed|
+---+---------+---------+
The steps are split up here for clarity; several of them can be combined into one.
Answer 1 (score: 2)
from pyspark.sql.functions import when
newdf = (trg.join(src, on=['key'], how='inner')
.select(trg.key,
when( src.merge==1, src.value)
.otherwise(trg.value).alias('value'),
when( src.merge==1, src.flag)
.otherwise(trg.flag).alias('flag')))
newdf.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+
Answer 2 (score: 0)
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import when
spark = SparkSession.builder.appName("test").getOrCreate()
data1 = [(1, 0.1, 'unchanged'), (2, 0.2,'unchanged'), (3, 0.3, 'unchanged')]
schema = ['key', 'value', 'flag']
df1 = spark.createDataFrame(data1, schema=schema)
df1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
data2 = [(1, 0.11, 'changed',0), (2, 0.22,'changed',1), (3, 0.33, 'changed',0)]
schema2 = ['key', 'value', 'flag', 'merge']
df2 = spark.createDataFrame(data2, schema=schema2)
df2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
df2 = df2.withColumnRenamed("value", "value1").withColumnRenamed("flag", 'flag1')
mer = df1.join(df2, ['key'], 'inner')
mer = mer.withColumn("temp", when(mer.merge == 1, mer.value1).otherwise(mer.value))
mer = mer.withColumn("temp1", when(mer.merge == 1, mer.flag1).otherwise(mer.flag))
output = mer.select(mer.key, mer.temp.alias('value'), mer.temp1.alias('flag'))
output.orderBy(output.value.asc()).show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+