Joining data in Scala using array_contains()

Asked: 2018-02-21 19:03:22

Tags: scala apache-spark-sql spark-dataframe

I have the following data in Scala, in a Spark environment:

// Needed outside spark-shell (the shell pre-imports these):
import org.apache.spark.sql.functions.{col, size}
import spark.implicits._

val abc = Seq(
  (Array("A"),0.1),
  (Array("B"),0.11),
  (Array("C"),0.12),
  (Array("A","B"),0.24),
  (Array("A","C"),0.27),
  (Array("B","C"),0.30),
  (Array("A","B","C"),0.4)
).toDF("channel_set", "rate")

abc.show(false)
abc.createOrReplaceTempView("abc")

val df = abc.withColumn("totalChannels", size(col("channel_set")))

scala> df.show
+-----------+----+-------------+
|channel_set|rate|totalChannels|
+-----------+----+-------------+
|        [A]| 0.1|            1|
|        [B]|0.11|            1|
|        [C]|0.12|            1|
|     [A, B]|0.24|            2|
|     [A, C]|0.27|            2|
|     [B, C]| 0.3|            2|
|  [A, B, C]| 0.4|            3|
+-----------+----+-------------+
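`size(col("channel_set"))` just measures the length of each array, and the filters below split rows on that length. A plain-Scala sketch of the same bucketing (illustration only, using hypothetical local values, not the Spark execution path):

```scala
// Hypothetical plain-Scala mirror of the channel sets above.
val channelSets = Seq(
  Array("A"), Array("B"), Array("C"),
  Array("A", "B"), Array("A", "C"), Array("B", "C"),
  Array("A", "B", "C")
)

// Grouping by array length mirrors size(col("channel_set")),
// which is then used to split into one- and two-channel frames.
val bySize = channelSets.groupBy(_.length)
```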



val oneChannelDF = df.filter($"totalChannels" === 1)
oneChannelDF.show()
oneChannelDF.createOrReplaceTempView("oneChannelDF")

+-----------+----+-------------+
|channel_set|rate|totalChannels|
+-----------+----+-------------+
|        [A]| 0.1|            1|
|        [B]|0.11|            1|
|        [C]|0.12|            1|
+-----------+----+-------------+


val twoChannelDF = df.filter($"totalChannels" === 2)
twoChannelDF.show()
twoChannelDF.createOrReplaceTempView("twoChannelDF")

+-----------+----+-------------+
|channel_set|rate|totalChannels|
+-----------+----+-------------+
|     [A, B]|0.24|            2|
|     [A, C]|0.27|            2|
|     [B, C]| 0.3|            2|
+-----------+----+-------------+

I want to join the oneChannel and twoChannel DataFrames so that my resulting data looks like this:

+-----------+----+-------------+------------+-------+
|channel_set|rate|totalChannels|channel_set | rate  |
+-----------+----+-------------+------------+-------+
|        [A]| 0.1|            1|     [A,B]  |  0.24 |
|        [A]| 0.1|            1|     [A,C]  |  0.27 |
|        [B]|0.11|            1|     [A,B]  |  0.24 |
|        [B]|0.11|            1|     [B,C]  |  0.30 |
|        [C]|0.12|            1|     [A,C]  |  0.27 |
|        [C]|0.12|            1|     [B,C]  |  0.30 |
+-----------+----+-------------+------------+-------+

Basically, I need every pair of rows where the channel from the oneChannel DataFrame is contained in the channel_set of the twoChannel DataFrame.

I tried:

spark.sql("""select * from oneChannelDF one inner join twoChannelDF two on array_contains(one.channel_set,two.channel_set)""").show()

However, I am facing this error:

org.apache.spark.sql.AnalysisException: cannot resolve 'array_contains(one.`channel_set`, two.`channel_set`)' due to data type mismatch: Arguments must be an array followed by a value of same type as the array members; line 1 pos 62;
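The error says that `array_contains(arr, value)` expects an array followed by a single value of the array's element type, while the query above passed a whole array as the second argument. The same array-versus-element distinction in plain Scala (illustrative, using `Array.contains`, which has the same array-plus-element shape):

```scala
val twoChannels = Array("A", "B")

// Element check: an array plus one element, the shape array_contains expects.
val hasA = twoChannels.contains("A")   // true

// The failing query effectively asked whether [A, B] contains the
// array [A], comparing an array against another array: a type mismatch.
```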

1 Answer:

Answer 0 (score: 1)

I think I figured out the mistake. I need to pass a single element, not an array, as the second argument to array_contains(). Since every array in oneChannelDF's channel_set column has size 1, indexing out its first element in the code below yields the correct DataFrame.

scala> spark.sql("""select * from oneChannelDF one inner join twoChannelDF two where array_contains(two.channel_set,one.channel_set[0])""").show()
+-----------+----+-------------+-----------+----+-------------+
|channel_set|rate|totalChannels|channel_set|rate|totalChannels|
+-----------+----+-------------+-----------+----+-------------+
|        [A]| 0.1|            1|     [A, B]|0.24|            2|
|        [A]| 0.1|            1|     [A, C]|0.27|            2|
|        [B]|0.11|            1|     [A, B]|0.24|            2|
|        [B]|0.11|            1|     [B, C]| 0.3|            2|
|        [C]|0.12|            1|     [A, C]|0.27|            2|
|        [C]|0.12|            1|     [B, C]| 0.3|            2|
+-----------+----+-------------+-----------+----+-------------+
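For reference, the join semantics of the accepted query can be sketched with plain Scala collections (hypothetical local `one`/`two` values mirroring the DataFrames above; an illustration of the logic, not the Spark execution path):

```scala
// Rows of oneChannelDF and twoChannelDF as (channel_set, rate) pairs.
val one = Seq((Array("A"), 0.1), (Array("B"), 0.11), (Array("C"), 0.12))
val two = Seq((Array("A", "B"), 0.24), (Array("A", "C"), 0.27), (Array("B", "C"), 0.30))

// Inner join on: two.channel_set contains one.channel_set(0),
// mirroring array_contains(two.channel_set, one.channel_set[0]).
val joined = for {
  (s1, r1) <- one
  (s2, r2) <- two
  if s2.contains(s1(0))
} yield (s1.mkString(","), r1, s2.mkString(","), r2)
```

In the DataFrame API the same condition could presumably be written with `org.apache.spark.sql.functions.array_contains`, e.g. `oneChannelDF.join(twoChannelDF, array_contains(twoChannelDF("channel_set"), oneChannelDF("channel_set")(0)))` (untested sketch). Note also that the accepted query spells the condition as `inner join ... where`; putting it in an `on` clause would be the more conventional form.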