I'm migrating the following function to a Spark SQL UDF.
DROP FUNCTION IF EXISTS anyarray_enumerate(anyarray);
CREATE FUNCTION anyarray_enumerate(anyarray)
    RETURNS TABLE (index bigint, value anyelement) AS
$$
SELECT
    row_number() OVER (),
    value
FROM (
    SELECT unnest($1) AS value
) AS unnested
$$
LANGUAGE sql IMMUTABLE;
I'm not getting output in Spark SQL similar to what I get in SQL. Any help or ideas?
demo=# select anyarray_enumerate(array[599,322,119,537]);
anyarray_enumerate
--------------------
(1,599)
(2,322)
(3,119)
(4,537)
(4 rows)
My current code is:
import scala.collection.mutable.WrappedArray

def anyarray_enumerate[T](anyarray: WrappedArray[T]) = anyarray.zipWithIndex

// Registers the function as a UDF so it can be used in SQL statements.
sqlContext.udf.register("anyarray_enumerate", anyarray_enumerate(_: WrappedArray[Int]))
Thanks
Answer 0 (score: 1)
Your UDF returns the whole array of tuples in a single row:
spark.sql("select anyarray_enumerate(array(599, 322, 119, 537)) as foo").show()
+--------------------+
| foo|
+--------------------+
|[[599,0], [322,1]...|
+--------------------+
But you can use the explode() function to split it into multiple rows:
spark.sql("select explode(anyarray_enumerate(array(599, 322, 119, 537))) as foo").show()
+-------+
| foo|
+-------+
|[599,0]|
|[322,1]|
|[119,2]|
|[537,3]|
+-------+
Also, the zipWithIndex method returns the value first and the index second, unlike the SQL command, but that is easily fixed in the UDF.
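For example, a minimal sketch of that fix, assuming the same sqlContext registration as in the question (the 1-based index mirrors row_number()):

import scala.collection.mutable.WrappedArray

// Emit (index, value) tuples with a 1-based index so the field order
// and numbering match the PostgreSQL row_number() output.
def anyarray_enumerate[T](anyarray: WrappedArray[T]): Seq[(Long, T)] =
  anyarray.zipWithIndex.map { case (value, i) => ((i + 1).toLong, value) }

sqlContext.udf.register("anyarray_enumerate", anyarray_enumerate(_: WrappedArray[Int]))

Combined with explode(), select explode(anyarray_enumerate(array(599, 322, 119, 537))) should then produce rows like [1,599], [2,322], and so on.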