我正在编写Java udf以处理数组类型列。
目的是处理一个字符串数组以选择长度最短的字符串
sqlContext.udf().register("NAME_SELECTOR", (UDF1<List<String>, String>) brandNames -> {
brandNames.sort(Comparator.comparing(String::length));
return brandNames.get(0);},DataTypes.StringType);
该错误与UDF函数的输入类型有关。我知道在scala中,我需要使用Seq[String]
作为输入类型,在Java中如何?
这是错误消息:
java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.List
答案 0 :(得分:0)
试试这个-
使用 scala.collection.mutable.WrappedArray
并使用JavaConverters
将其转换为Java列表,然后使用比较器对其进行排序并获取第一个最短的字符串-
Dataset<Row> df = spark.sql("select array('abc', 'ab', 'a') arr");
df.printSchema();
df.show(false);
/**
* root
* |-- arr: array (nullable = false)
* | |-- element: string (containsNull = false)
*
* +------------+
* |arr |
* +------------+
* |[abc, ab, a]|
* +------------+
*/
// scala.collection.mutable.WrappedArray
UserDefinedFunction shortestStringUdf = udf((WrappedArray<String> arr) -> {
List<String> strings = new ArrayList<>(JavaConverters
.asJavaCollectionConverter(arr)
.asJavaCollection());
strings.sort(Comparator.comparing(String::length));
return strings.get(0);
}
, DataTypes.StringType);
spark.udf().register("shortestString", shortestStringUdf);
df.withColumn("a", expr("shortestString(arr)"))
.show(false);
/**
* +------------+---+
* |arr |a |
* +------------+---+
* |[abc, ab, a]|a |
* +------------+---+
*/
如果您在 spark>=2.4
上,请使用高阶函数获得与以下相同的结果 without udf
-
// spark>=2.4
df.withColumn("arr_length", expr("TRANSFORM(arr, x -> length(x))"))
.withColumn("a", expr("array_sort(arrays_zip(arr_length, arr))[0].arr"))
.show(false);
/**
* +------------+----------+---+
* |arr |arr_length|a |
* +------------+----------+---+
* |[abc, ab, a]|[3, 2, 1] |a |
* +------------+----------+---+
*/