我有一个数组作为广播变量,它包含整数:
broadcast_array.value
Array(72159153, 72159163, 72159202, 72159203, 72159238, 72159398, 72159447, 72159448, 72159455, 72159492...
我在数据集中有一个列(调用col_id
,其中包含可能位于IntegerType
中的broadcast_array
值,但它们可能不会。
我只是想创建一个新列(称之为new_col
),检查每行的col_id
值是否在broadcast_array
。如果是,则新列值应为Available
,否则可为null
所以我有类似的东西:
val my_new_df = df.withColumn("new_col", when(broadcast_array.value.contains($"col_id"), "Available"))
但我一直收到这个错误:
Name: Unknown Error
Message: <console>:45: error: type mismatch;
found : Boolean
required: org.apache.spark.sql.Column
val my_new_df = df.withColumn("new_col", when(broadcast_array.value.contains($"col_id"), "Available"))
^
StackTrace:
最令我困惑的是,我认为when
语句需要输出一些布尔值的条件,但是这里说它需要一个列。
我应该如何根据是否可以在预定义数组中找到现有列中的值来为新列添加值?
答案 0 :(得分:0)
如果查看when
函数
def when(condition:org.apache.spark.sql.Column,value:scala.Any):org.apache.spark.sql.Column
很明显所需的条件是一列而不是一个布尔。
因此,您可以执行复杂的lit
组合,将boolean
转换为column
import org.apache.spark.sql.functions._
df.withColumn("new_col", when(lit(broadcast_array.value.mkString(",")).contains($"col_id"), lit("Available"))).show(false)
OR
您可以通过编写简单的udf
函数
import org.apache.spark.sql.functions._
val broadcastContains = udf((id: Int) => broadcast_array.value.contains(id))
并将该函数调用为
df.withColumn("new_col", when(broadcastContains($"col_id"), lit("Available"))).show(false)
答案 1 :(得分:0)
我在spark-daria上添加了一个broadcastArrayContains
函数,使Ramesh的解决方案更具可重用性/可访问性。
def broadcastArrayContains[T](col: Column, broadcastedArray: Broadcast[Array[T]]) = {
when(col.isNull, null)
.when(lit(broadcastedArray.value.mkString(",")).contains(col), lit(true))
.otherwise(lit(false))
}
假设您具有以下DataFrame(df
):
+----+
| num|
+----+
| 123|
| hi|
|null|
+----+
您可以按如下所示标识广播数组中的所有值:
val specialNumbers = spark.sparkContext.broadcast(Array("123", "456"))
df.withColumn(
"is_special_number",
functions.broadcastArrayContains[String](col("num"), specialNumbers)
)