在Spark数据框中,如果另一列中的值在广播变量数组中,则向新列添加值

时间:2017-12-11 04:33:27

标签: scala apache-spark merge apache-spark-sql

我有一个数组作为广播变量,它包含整数:

broadcast_array.value
Array(72159153, 72159163, 72159202, 72159203, 72159238, 72159398, 72159447, 72159448, 72159455, 72159492...

我在数据集中有一个列(调用col_id,其中包含可能位于IntegerType中的broadcast_array值,但它们可能不会。

我只是想创建一个新列(称之为new_col),检查每行的col_id值是否在broadcast_array。如果是,则新列值应为Available,否则可为null

所以我有类似的东西:

val my_new_df = df.withColumn("new_col", when(broadcast_array.value.contains($"col_id"), "Available"))

但我一直收到这个错误:

Name: Unknown Error
Message: <console>:45: error: type mismatch;
 found   : Boolean
 required: org.apache.spark.sql.Column
   val my_new_df = df.withColumn("new_col", when(broadcast_array.value.contains($"col_id"), "Available"))
                                                                                           ^
StackTrace: 

最令我困惑的是,我认为when语句需要输出一些布尔值的条件,但是这里说它需要一个列。

我应该如何根据是否可以在预定义数组中找到现有列中的值来为新列添加值?

2 个答案:

答案 0 :(得分:0)

如果查看when函数

api
  

def when(condition:org.apache.spark.sql.Column,value:scala.Any):org.apache.spark.sql.Column

很明显所需的条件是一列而不是一个布尔

因此,您可以执行复杂的lit组合,将boolean转换为column

import org.apache.spark.sql.functions._
df.withColumn("new_col", when(lit(broadcast_array.value.mkString(",")).contains($"col_id"), lit("Available"))).show(false)

OR

您可以通过编写简单的udf函数

来实现您的目标
import org.apache.spark.sql.functions._
val broadcastContains = udf((id: Int) => broadcast_array.value.contains(id))

并将该函数调用为

df.withColumn("new_col", when(broadcastContains($"col_id"), lit("Available"))).show(false)

答案 1 :(得分:0)

我在spark-daria上添加了一个broadcastArrayContains函数,使Ramesh的解决方案更具可重用性/可访问性。

def broadcastArrayContains[T](col: Column, broadcastedArray: Broadcast[Array[T]]) = {
  when(col.isNull, null)
    .when(lit(broadcastedArray.value.mkString(",")).contains(col), lit(true))
    .otherwise(lit(false))
}

假设您具有以下DataFrame(df):

+----+
| num|
+----+
| 123|
|  hi|
|null| 
+----+

您可以按如下所示标识广播数组中的所有值:

val specialNumbers = spark.sparkContext.broadcast(Array("123", "456"))

df.withColumn(
  "is_special_number",
  functions.broadcastArrayContains[String](col("num"), specialNumbers)
)