返回返回整数数组的Java spark UDF给了我ClassException

时间:2017-03-14 16:50:29

标签: java scala apache-spark

6我正在尝试编写一个UDF,它接受一个字符串数组,如下所示:

String[] lol = {"1,2,3","1,2,3","2,3,4","1,4,5,6,7"};

我希望UDF返回一个没有重复的整数数组。

我首先在我的查询表单中收集一个DataFrame,它有两个字段,userid和category,这是一个看起来像" 1,2,3"并按用户分组。

df.groupBy("userid").agg(collect_list("category").as("categories")).write().mode(SaveMode.Overwrite).parquet("path"); 然后我想运行我的UDF:

ctx.read().parquet("path").select(col("userid"), trimCategories3("categories", ctx).as("categories")).show();

我的UDF:

public static Column trimCategories3(String column, SQLContext ctx) {
  UDF1 udf = new UDF1<String[], Integer[]>() {
    @Override
    public Integer[] call(String[] categories) throws Exception {
      Set<Integer> result = new HashSet<>();
      for(String s : categories) {
        Set<Integer> med = Arrays.stream(s.split("\\,"))
            .map(Integer::parseInt)
            .collect(Collectors.toSet());
        result.addAll(med);
      }
      return result.toArray(new Integer[result.size()]);
    }
  };
  ctx.udf().register("trimCategories", udf, DataTypes.createArrayType(DataTypes.IntegerType));
  return callUDF("trimCategories", col(column));
}

这给了我: java.lang.ClassCastException:scala.collection.mutable.WrappedArray $ ofRef无法强制转换为[Ljava.lang.String;

由于我是编程新手并且不了解Scala,我可以使用一些帮助。在火花错误日志中,我得到了UDF1 udf = new UDF1<String[], Integer[]>() {开始的行号。 当我尝试在测试类中运行它时,代码可以正常工作。会很感激一些指导。 干杯!

1 个答案:

答案 0 :(得分:0)

发现问题,即UDF的输入类型是Scala WrappedArray。奇怪的部分(对我来说)是collect_list函数(在文档中应该返回一个列表)返回一个Scala Wrapped数组,当我之前运行printSchema函数时它表示类型array:string。这就是为什么我将UDF的Input类型设置为String []的原因。代码中的解决方案:

public static Column trimCategories3(String column, SQLContext ctx) {
UDF1 udf = new UDF1<WrappedArray<String>, Integer[]>() {
  @Override
  public Integer[] call(WrappedArray<String> categories) throws Exception {
    Set<Integer> result = new HashSet<>();
    scala.collection.Iterator it = categories.iterator();
    while (it.hasNext()) {
      String s = (String) it.next();
      Set<Integer> med = Arrays.stream(s.split("\\,"))
          .map(Integer::parseInt)
          .collect(Collectors.toSet());
      result.addAll(med);
    }
    return result.stream().toArray(Integer[]::new);
  }
};
ctx.udf().register("trimCategories", udf, DataTypes.createArrayType(DataTypes.IntegerType));
return callUDF("trimCategories", col(column));
}