Java UDF处理数组列

时间:2020-08-02 08:09:56

标签: apache-spark

我正在编写Java udf以处理数组类型列。

目的是处理一个字符串数组以选择长度最短的字符串

sqlContext.udf().register("NAME_SELECTOR", (UDF1<List<String>, String>) brandNames -> {
                          brandNames.sort(Comparator.comparing(String::length));
                          return brandNames.get(0);},DataTypes.StringType);

该错误与UDF函数的输入类型有关。我知道在scala中,我需要使用Seq[String]作为输入类型,在Java中如何?

这是错误消息:

java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.List

1 个答案:

答案 0 :(得分:0)

试试这个-

使用 scala.collection.mutable.WrappedArray 并使用JavaConverters将其转换为Java列表,然后使用比较器对其进行排序并获取第一个最短的字符串-

Dataset<Row> df = spark.sql("select array('abc', 'ab', 'a') arr");
        df.printSchema();
        df.show(false);
        /**
         * root
         *  |-- arr: array (nullable = false)
         *  |    |-- element: string (containsNull = false)
         *
         * +------------+
         * |arr         |
         * +------------+
         * |[abc, ab, a]|
         * +------------+
         */

        // scala.collection.mutable.WrappedArray
        UserDefinedFunction shortestStringUdf = udf((WrappedArray<String> arr)  -> {
                    List<String> strings = new ArrayList<>(JavaConverters
                            .asJavaCollectionConverter(arr)
                            .asJavaCollection());
                    strings.sort(Comparator.comparing(String::length));
                    return strings.get(0);
                }
                , DataTypes.StringType);
        spark.udf().register("shortestString", shortestStringUdf);

        df.withColumn("a", expr("shortestString(arr)"))
        .show(false);
        /**
         * +------------+---+
         * |arr         |a  |
         * +------------+---+
         * |[abc, ab, a]|a  |
         * +------------+---+
         */

如果您在 spark>=2.4 上,请使用高阶函数获得与以下相同的结果 without udf -

 // spark>=2.4
        df.withColumn("arr_length", expr("TRANSFORM(arr, x -> length(x))"))
                .withColumn("a", expr("array_sort(arrays_zip(arr_length, arr))[0].arr"))
                .show(false);
        /**
         * +------------+----------+---+
         * |arr         |arr_length|a  |
         * +------------+----------+---+
         * |[abc, ab, a]|[3, 2, 1] |a  |
         * +------------+----------+---+
         */