Spark: how to query an array in a column?

Time: 2017-10-19 15:06:11

Tags: scala apache-spark

I'm new to Spark (and to Stack Overflow). For the RDD and DataFrame below (which hold the same data), I would like to get the tags of playlists containing more than N videos, together with the total views for each tag. My problem is that the tags are stored in an array, and I don't know where to start since this looks advanced.

RDD

(id,playlist,tags,videos,views)
(1,playlist_1,[t1, t2, t3],9,200)
(2,playlist_2,[t4, t5, t7],64,793)
(3,playlist_3,[t4, t6, t3],51,114)
(4,playlist_4,[t1, t6, t2],8,115)
(5,playlist_5,[t1, t6, t2],51,256)
(2,playlist_6,[t4, t5, t2],66,553)
(3,playlist_7,[t4, t6, t2],77,462)

DataFrame

+---+------------+--------------+--------+-------+
| id| playlist   | tags         | videos | views |
+---+------------+--------------+--------+-------+
| 1 | playlist_1 | [t1, t2, t3] | 9      |  200  |
| 2 | playlist_2 | [t4, t5, t7] | 64     |  793  |
| 3 | playlist_3 | [t4, t6, t3] | 51     |  114  |
| 4 | playlist_4 | [t1, t6, t2] | 8      |  115  |
| 5 | playlist_5 | [t1, t6, t2] | 51     |  256  |
| 2 | playlist_6 | [t4, t5, t2] | 66     |  553  |
| 3 | playlist_7 | [t4, t6, t2] | 77     |  462  |
+---+------------+--------------+--------+-------+

Expected result

Tags of playlists containing more than (N = 65) videos:

+-----+-------+
| tag | views |
+-----+-------+
| t2  | 1015  |
| t4  | 1015  |
| t5  | 553   |
| t6  | 462   |
+-----+-------+

1 Answer:

Answer 0 (score: 1)

Here is a solution using DataFrames:

import org.apache.spark.sql.functions._
import spark.implicits._

val N = 65

val result = df.where($"videos" > N)           // filter playlists with more than N videos
  .select(explode($"tags") as "tag", $"views") // explode tags into separate records
  .groupBy("tag")                              // group by tag
  .sum("views")                                // sum views per tag

result.show(false)
// +---+----------+
// |tag|sum(views)|
// +---+----------+
// |t5 |553       |
// |t4 |1015      |
// |t2 |1015      |
// |t6 |462       |
// +---+----------+

Using RDDs:

// given
import org.apache.spark.rdd.RDD
val rdd: RDD[(Int, String, Array[String], Int, Int)] = ???

val N = 65

val result: RDD[(String, Int)] = rdd
  .filter(_._4 > N)
  .flatMap { case (_, _, tags, _, views) => tags.map(tag => (tag, views)) }
  .reduceByKey(_ + _)
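
Both pipelines apply the same three steps: filter on `videos`, flatten each playlist into one `(tag, views)` pair per tag, then sum views per tag. As an illustrative sketch (not part of the original answer), the same logic can be traced on plain Scala collections with the sample rows from the question, with no Spark runtime required — `filter` plays the role of `where`, `flatMap` of `explode`, and `groupBy` + `sum` of `reduceByKey`:

```scala
// Sample rows from the question: (id, playlist, tags, videos, views)
val rows = Seq(
  (1, "playlist_1", Seq("t1", "t2", "t3"), 9, 200),
  (2, "playlist_2", Seq("t4", "t5", "t7"), 64, 793),
  (3, "playlist_3", Seq("t4", "t6", "t3"), 51, 114),
  (4, "playlist_4", Seq("t1", "t6", "t2"), 8, 115),
  (5, "playlist_5", Seq("t1", "t6", "t2"), 51, 256),
  (2, "playlist_6", Seq("t4", "t5", "t2"), 66, 553),
  (3, "playlist_7", Seq("t4", "t6", "t2"), 77, 462)
)

val N = 65

val result: Map[String, Int] = rows
  .filter(_._4 > N)                               // keep playlists with more than N videos
  .flatMap { case (_, _, tags, _, views) =>
    tags.map(tag => (tag, views))                 // one (tag, views) pair per tag, like explode
  }
  .groupBy(_._1)                                  // group the pairs by tag
  .map { case (tag, pairs) =>                     // sum the views within each group
    tag -> pairs.map(_._2).sum
  }

result.toSeq.sortBy(_._1).foreach(println)
// (t2,1015)
// (t4,1015)
// (t5,553)
// (t6,462)
```

Only `playlist_6` (66 videos) and `playlist_7` (77 videos) pass the filter, so t2 and t4 each accumulate 553 + 462 = 1015, matching the expected result table.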