Filtering an un-exploded struct array in a Spark DataFrame using Scala

Date: 2018-09-27 09:42:42

Tags: scala apache-spark

I have:

 +-----------------------+-------+------------------------------------+
 |cities                 |name   |schools                             |
 +-----------------------+-------+------------------------------------+
 |[palo alto, menlo park]|Michael|[[stanford, 2010], [berkeley, 2012]]|
 |[santa cruz]           |Andy   |[[ucsb, 2011]]                      |
 |[portland]             |Justin |[[berkeley, 2014]]                  |
 +-----------------------+-------+------------------------------------+

I can do the following without any trouble:

 val res = df.select ("*").where (array_contains (df("schools.sname"), "berkeley")).show(false)

But I don't want to explode or use a UDF. How can I, in the same or a similar way as above, do the following:

 return all rows where at least 1 schools.sname starts with "b"  ?

For example:

 val res = df.select ("*").where (startsWith (df("schools.sname"), "b")).show(false)

This is of course wrong and is only there to illustrate the point. But how can I do something similar, filtering on a condition that returns true/false, without exploding and without a UDF? It may well be impossible; I cannot find any examples of this. Or do I need expr?

The answers received show how some things have to be done a certain way because certain functions do not yet exist in the Scala API. I read an article stating that new array functions were to be implemented after this, which bears that out.
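
For completeness, on newer Spark releases this can now be expressed directly. A minimal sketch, assuming Spark 2.4+ where the exists higher-order function is available inside expr (newer than the version the question was asked against):

 import org.apache.spark.sql.functions.expr

 // keep rows where at least one schools.sname starts with "b"
 df.where(expr("exists(schools, s -> s.sname like 'b%')")).show(false)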

2 answers:

Answer 0 (score: 1):

I'm not sure whether this counts as a UDF, but you can define a new filter function. If you work with a Dataset[Student], where:

case class School(sname: String, year: Int)
case class Student(cities: Seq[String], name: String, schools: Seq[School])

then you can simply do the following:

students
    .filter(
        r => r.schools.filter(_.sname.startsWith("b")).size > 0)
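
(As a side note, an equivalent and slightly more idiomatic predicate would be r => r.schools.exists(_.sname.startsWith("b")).)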

However, if you are working with a plain DataFrame:

import org.apache.spark.sql.Row

students.toDF
    .filter(r =>
        r.getAs[Seq[Row]]("schools")
            .filter(_.getAs[String]("sname").startsWith("b"))
            .size > 0)

Both result in:

+-----------------------+-------+------------------------------------+
|cities                 |name   |schools                             |
+-----------------------+-------+------------------------------------+
|[palo alto, menlo park]|Michael|[[stanford, 2010], [berkeley, 2012]]|
|[portland]             |Justin |[[berkeley, 2014]]                  |
+-----------------------+-------+------------------------------------+

Answer 1 (score: 1):

How about this?

scala> val df = Seq ( ( Array("palo alto", "menlo park"), "Michael", Array(("stanford", 2010), ("berkeley", 2012))),
     |     (Array(("santa cruz")),"Andy",Array(("ucsb", 2011))),
     |       (Array(("portland")),"Justin",Array(("berkeley", 2014)))
     |     ).toDF("cities","name","schools")
df: org.apache.spark.sql.DataFrame = [cities: array<string>, name: string ... 1 more field]
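
Since schools is built from tuples, Spark names its struct fields _1 and _2, which is why schools._1 below pulls out just the array of school names.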

scala> val df2 = df.select ("*").withColumn("sch1",df("schools._1"))
df2: org.apache.spark.sql.DataFrame = [cities: array<string>, name: string ... 2 more fields]

scala> val df3=df2.select("*").withColumn("sch2",concat_ws(",",df2("sch1")))
df3: org.apache.spark.sql.DataFrame = [cities: array<string>, name: string ... 3 more fields]

scala> df3.select("*").where( df3("sch2") rlike "^b|,b" ).show(false)
+-----------------------+-------+------------------------------------+--------------------+-----------------+
|cities                 |name   |schools                             |sch1                |sch2             |
+-----------------------+-------+------------------------------------+--------------------+-----------------+
|[palo alto, menlo park]|Michael|[[stanford, 2010], [berkeley, 2012]]|[stanford, berkeley]|stanford,berkeley|
|[portland]             |Justin |[[berkeley, 2014]]                  |[berkeley]          |berkeley         |
+-----------------------+-------+------------------------------------+--------------------+-----------------+

One more step and you can drop the columns you don't need.
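
For example, that last step might look like this (just a sketch, dropping the two helper columns introduced above):

 df3.select("*").where( df3("sch2") rlike "^b|,b" ).drop("sch1", "sch2").show(false)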