Spark: joining the same dataset multiple times on different columns

Date: 2018-11-18 05:30:44

Tags: apache-spark apache-spark-sql

I have the following two datasets.


code,name
IN,India
US,United States
UK,United Kingdom
SG,Singapore 

id,name,code1,code2,code3
1,abc,UK,SG,US
2,efg,SG,UK,US

Can we join code1, code2 and code3 against the first dataset and get the name for each of those columns, like this?


id,name,code1desc,code2desc,code3desc
1,abc,United Kingdom,Singapore,United States
2,efg,Singapore,United Kingdom,United States

The first join works fine, but the second one fails.

Dataset<Row> code1 = people
    .join(countries, people.col("code1").equalTo(countries.col("code")), "left_outer")
    .withColumnRenamed("name", "code1desc");
code1.show();

The following code fails, most likely because code1 still carries the code and name columns of countries, so referencing countries.col("code") again resolves ambiguously:

Dataset<Row> code2 = code1
    .join(countries, code1.col("code2").equalTo(countries.col("code")), "left_outer");
code2.show();

2 Answers:

Answer 0 (score: 0)

Each "code[i]" column of people needs its own join with countries; in Scala this can be done in a loop (a fold):

// imports needed for toDF(...), col(...) and the $"..." syntax
import org.apache.spark.sql.functions.col
import spark.implicits._

// data
val countries = List(
  ("IN", "India"),
  ("US", "United States"),
  ("UK", "United Kingdom"),
  ("SG", "Singapore")
).toDF("code", "name")

val people = List(
  (1, "abc", "UK", "SG", "US"),
  (2, "efg", "SG", "UK", "US")
).toDF("id", "name", "code1", "code2", "code3")

// action
val countryColumns = List("code1", "code2", "code3")
val result = countryColumns.foldLeft(people)((people, column) =>
  people.alias("p")
    .join(countries.withColumnRenamed("name", column + "desc").alias("c"),
      col("p." + column) === $"c.code",
      "left_outer")
    .drop(column, "code")
)

The result is:

+---+----+--------------+--------------+-------------+
|id |name|code1desc     |code2desc     |code3desc    |
+---+----+--------------+--------------+-------------+
|1  |abc |United Kingdom|Singapore     |United States|
|2  |efg |Singapore     |United Kingdom|United States|
+---+----+--------------+--------------+-------------+

Note: if the countries dataframe is small, a broadcast join can be used for better performance.
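The foldLeft pattern above can be sanity-checked without Spark using plain Scala collections. This is a toy sketch (the Maps below are hypothetical stand-ins for the two DataFrames): each fold step consumes one code column and replaces it with its description, just as each iteration above adds one join.

```scala
// Spark-free sketch of the foldLeft chaining: each step replaces one
// code column with its looked-up description (stand-in for one join).
val countries = Map(
  "IN" -> "India", "US" -> "United States",
  "UK" -> "United Kingdom", "SG" -> "Singapore")

val people = List(
  Map("id" -> "1", "name" -> "abc", "code1" -> "UK", "code2" -> "SG", "code3" -> "US"),
  Map("id" -> "2", "name" -> "efg", "code1" -> "SG", "code2" -> "UK", "code3" -> "US"))

val codeColumns = List("code1", "code2", "code3")

// left_outer semantics: an unknown code yields null instead of failing
val result = codeColumns.foldLeft(people) { (rows, column) =>
  rows.map(row => (row - column) + (s"${column}desc" -> countries.getOrElse(row(column), null)))
}

result.foreach(println)
```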

Answer 1 (score: 0)

You can use a udf, provided your country-code dataframe is small enough. First we collect the codes into a map, then apply the udf to each code column.

Here code_df is your country-code dataframe and data_df is your data.

import org.apache.spark.sql.functions._

val mapcode = code_df.rdd.keyBy(row => row.getString(0)).collectAsMap()  // code -> full Row
println("Showing 10 rows of mapcode")

for ((k,v) <- mapcode) {
  printf("key: %s, value: %s\n", k, v)
}


def getCode(code: String): String =
  mapcode(code).getAs[String](1)  // column 1 holds the country name

val getcode_udf = udf(getCode _)

val newdatadf = data_df.withColumn("code1desc", getcode_udf($"code1"))
  .withColumn("code2desc", getcode_udf($"code2"))
  .withColumn("code3desc", getcode_udf($"code3"))

println("Showing 10 rows of final result")
newdatadf.show(10, truncate = false)

Here is the result:

Showing 10 rows of mapcode
key: IN, value: [IN,India]
key: SG, value: [SG,Singapore]
key: UK, value: [UK,United Kingdom]
key: US, value: [US,United States]
Showing 10 rows of final result
+---+----+-----+-----+-----+--------------+--------------+-------------+
|id |name|code1|code2|code3|code1desc     |code2desc     |code3desc    |
+---+----+-----+-----+-----+--------------+--------------+-------------+
|1  |abc |UK   |SG   |US   |United Kingdom|Singapore     |United States|
|2  |efg |SG   |UK   |US   |Singapore     |United Kingdom|United States|
+---+----+-----+-----+-----+--------------+--------------+-------------+
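The udf's lookup step can likewise be checked without Spark. A minimal sketch, assuming the collected map is simplified to code -> name (in the real code above it maps each code to a full Row):

```scala
// Spark-free sketch of the collected lookup map backing the udf.
val mapcode = Map(
  "IN" -> "India", "US" -> "United States",
  "UK" -> "United Kingdom", "SG" -> "Singapore")

// Mirrors getCode: resolve one country code to its name
def getCode(code: String): String = mapcode(code)

// Applying it across the three code columns of one row
val decoded = List("UK", "SG", "US").map(getCode)
println(decoded)
```

One caveat this makes visible: mapcode(code) throws NoSuchElementException for a code missing from the map, so unlike a left outer join the udf fails on unmatched codes; mapcode.getOrElse(code, null) would mirror the join's behavior.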