6我正在尝试编写一个UDF,它接受一个字符串数组,如下所示:
String[] lol = {"1,2,3","1,2,3","2,3,4","1,4,5,6,7"};
我希望UDF返回一个没有重复的整数数组。
我首先在我的查询表单中收集一个DataFrame,它有两个字段,userid和category,这是一个看起来像" 1,2,3"并按用户分组。
df.groupBy("userid").agg(collect_list("category").as("categories")).write().mode(SaveMode.Overwrite).parquet("path");
然后我想运行我的UDF:
ctx.read().parquet("path").select(col("userid"), trimCategories3("categories", ctx).as("categories")).show();
我的UDF:
public static Column trimCategories3(String column, SQLContext ctx) {
UDF1 udf = new UDF1<String[], Integer[]>() {
@Override
public Integer[] call(String[] categories) throws Exception {
Set<Integer> result = new HashSet<>();
for(String s : categories) {
Set<Integer> med = Arrays.stream(s.split("\\,"))
.map(Integer::parseInt)
.collect(Collectors.toSet());
result.addAll(med);
}
return result.toArray(new Integer[result.size()]);
}
};
ctx.udf().register("trimCategories", udf, DataTypes.createArrayType(DataTypes.IntegerType));
return callUDF("trimCategories", col(column));
}
这给了我: java.lang.ClassCastException:scala.collection.mutable.WrappedArray $ ofRef无法强制转换为[Ljava.lang.String;
由于我是编程新手并且不了解Scala,我可以使用一些帮助。在火花错误日志中,我得到了UDF1 udf = new UDF1<String[], Integer[]>() {
开始的行号。
当我尝试在测试类中运行它时,代码可以正常工作。会很感激一些指导。
干杯!
答案 0 :(得分:0)
发现问题,即UDF的输入类型是Scala WrappedArray。奇怪的部分(对我来说)是collect_list函数(在文档中应该返回一个列表)返回一个Scala Wrapped数组,当我之前运行printSchema函数时它表示类型array:string。这就是为什么我将UDF的Input类型设置为String []的原因。代码中的解决方案:
public static Column trimCategories3(String column, SQLContext ctx) {
UDF1 udf = new UDF1<WrappedArray<String>, Integer[]>() {
@Override
public Integer[] call(WrappedArray<String> categories) throws Exception {
Set<Integer> result = new HashSet<>();
scala.collection.Iterator it = categories.iterator();
while (it.hasNext()) {
String s = (String) it.next();
Set<Integer> med = Arrays.stream(s.split("\\,"))
.map(Integer::parseInt)
.collect(Collectors.toSet());
result.addAll(med);
}
return result.stream().toArray(Integer[]::new);
}
};
ctx.udf().register("trimCategories", udf, DataTypes.createArrayType(DataTypes.IntegerType));
return callUDF("trimCategories", col(column));
}