数据框下面有2列,
要求是在user_id_list中找到user_id的位置。
样本记录:
user_id = x1
user_id_list = ('X2','X1','X3','X6')
结果:
postition = 2
我需要带有第3列的数据框,该列在列表中具有user_id的位置。
结果数据框列:
使用find_in_set()
将数据框注册为视图后,我可以使用createOrReplaceTempView
hive函数实现此目的。
在没有注册视图的情况下,是否可以在spark中使用sql函数来完成此操作?
答案 0 :(得分:1)
我的建议是实施UDF,就像Yura提到的那样。这是一个简短的例子:
val spark = SparkSession.builder.getOrCreate()
import spark.implicits._
val df = List((1, Array(2, 3, 1)), (2, Array(1, 2,3))).toDF("user_id","user_id_list")
df.show
+-------+------------+
|user_id|user_id_list|
+-------+------------+
| 1| [2, 3, 1]|
| 2| [1, 2, 3]|
+-------+------------+
val findPosition = udf((user_id: Int, user_id_list: Seq[Int]) => {
user_id_list.indexOf(user_id)
})
val df2 = df.withColumn("position", findPosition($"user_id", $"user_id_list"))
df2.show
+-------+------------+--------+
|user_id|user_id_list|position|
+-------+------------+--------+
| 1| [2, 3, 1]| 2|
| 2| [1, 2, 3]| 1|
+-------+------------+--------+
答案 1 :(得分:0)
我不知道这样的功能是Spark SQL API。有一个函数可以找到数组是否包含一个值(称为array_contains
),但这不是你需要的。
您可以使用posexplode
将数组分解为具有位置的行,然后按其进行过滤,如下所示:dataframe.select($"id", posexplode($"ids")).filter($"id" === $"col").select($"id", $"pos")
。仍然可能不是最佳解决方案,具体取决于用户ID列表的长度。目前(对于版本2.1.1)Spark不进行优化以使用直接数组查找替换上述代码 - 它将生成行并按其过滤。
另请注意,此方法会过滤掉user_id
中user_ids_list
不在else
的任何行,因此您可能需要付出额外的努力来克服这个问题。
我建议实施完全符合您需要的UDF。缺点是:Spark无法查看UDF,因此它必须将数据反序列化为Java对象并返回。
答案 2 :(得分:0)
在没有注册视图的情况下,是否可以在spark中使用sql函数来完成此操作?
不,但您不必注册DataFrame以使用find_in_set
。
您可以(暂时)使用expr
函数切换到SQL模式(请参阅functions对象):
将表达式字符串解析为它所代表的列
val users = Seq(("x1", Array("X2","X1","X3","X6"))).toDF("user_id", "user_id_list")
val positions = users.
as[(String, Array[String])].
map { case (uid, ids) => (uid, ids, ids.mkString(",")) }.
toDF("user_id", "user_id_list", "ids"). // only for nicer column names
withColumn("position", expr("find_in_set(upper(user_id), ids)")).
select("user_id", "user_id_list", "position")
scala> positions.show
+-------+----------------+--------+
|user_id| user_id_list|position|
+-------+----------------+--------+
| x1|[X2, X1, X3, X6]| 2|
+-------+----------------+--------+
您还可以使用posexplode
函数(来自functions对象)来节省一些Scala自定义编码,并且比UDF更优化(强制将内部二进制行反序列化为JVM对象)。
scala> users.
select('*, posexplode($"user_id_list")).
filter(lower($"user_id") === lower($"col")).
select($"user_id", $"user_id_list", $"pos" as "position").
show
+-------+----------------+--------+
|user_id| user_id_list|position|
+-------+----------------+--------+
| x1|[X2, X1, X3, X6]| 1|
+-------+----------------+--------+