How do I pass an array of columns to a Spark user-defined function in Java?

Time: 2019-07-03 12:01:07

Tags: java apache-spark

My Spark Dataset has a dynamic set of columns, and I want to pass an array of columns to a UDF rather than passing each column individually. How can I write the UDF so that it accepts an array of columns?

I tried passing a sequence of strings, but it failed.

    static UDF1<Seq<String>, String> udf = new UDF1<Seq<String>, String>() {

        @Override
        public String call(Seq<String> t1) throws Exception {
            return t1.toString();
        }
    };

    private static Column generate(Dataset<Row> dataset, SparkSession ss) {

        ss.udf().register("generate", udf, DataTypes.StringType);

        StructField[] columnsStructType = dataset.schema().fields();

        List<Column> columnList = new ArrayList<>();

        for (StructField structField : columnsStructType) {
            columnList.add(dataset.col(structField.name()));
        }

        return functions.callUDF("generate", convertListToSeq(columnList));
    }

    private static Seq<Column> convertListToSeq(List<Column> inputList) {
        return JavaConverters.asScalaIteratorConverter(inputList.iterator()).asScala().toSeq();
    }

When I try to call the generate function, I get the following error message:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Invalid number of arguments for function generate. Expected: 1; Found: 14;
    at org.apache.spark.sql.UDFRegistration.builder$27(UDFRegistration.scala:763)
    at org.apache.spark.sql.UDFRegistration.$anonfun$register$377(UDFRegistration.scala:766)
    at org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry.lookupFunction(FunctionRegistry.scala:115)
    at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction(SessionCatalog.scala:1273)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.$anonfun$applyOrElse$66(Analyzer.scala:1329)
    at org.apache.spark.sql.catalyst.analysis.package$.withPosition(package.scala:53)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.applyOrElse(Analyzer.scala:1329)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.applyOrElse(Analyzer.scala:1312)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:256)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:256)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$3(TreeNode.scala:261)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:326)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:324)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:261)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsDown$1(QueryPlan.scala:83)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:105)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:105)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:116)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$3(QueryPlan.scala:121)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:233)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:58)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:51)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at scala.collection.TraversableLike.map(TraversableLike.scala:233)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:226)
    at scala.collection.AbstractTraversable.map(Traversable.scala:104)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:121)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$4(QueryPlan.scala:126)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:126)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:83)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:74)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13.applyOrElse(Analyzer.scala:1312)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13.applyOrElse(Analyzer.scala:1310)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$3(AnalysisHelper.scala:90)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$1(AnalysisHelper.scala:90)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:194)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:86)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:84)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:29)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$.apply(Analyzer.scala:1310)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$.apply(Analyzer.scala:1309)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:87)
    at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:122)
    at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:118)
    at scala.collection.immutable.List.foldLeft(List.scala:85)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:84)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:76)
    at scala.collection.immutable.List.foreach(List.scala:388)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:76)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:127)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:121)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:106)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:201)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:105)
    at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:57)
    at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:55)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:47)
    at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:79)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withPlan(Dataset.scala:3407)
    at org.apache.spark.sql.Dataset.select(Dataset.scala:1336)
    at org.apache.spark.sql.Dataset.withColumns(Dataset.scala:2253)
    at org.apache.spark.sql.Dataset.withColumn(Dataset.scala:2220)

1 Answer:

Answer 0 (score: 0)

In short: before passing the columns to the UDF, combine them into a single array column with the array function. The UDF above is registered with exactly one parameter, but callUDF("generate", convertListToSeq(columnList)) expands the column sequence into one argument per column, which is why the analyzer reports "Expected: 1; Found: 14".
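
Applied to the generate method from the question, the essential change is to wrap the column sequence with functions.array before calling the UDF (a sketch; the UDF's declared element type may also need adjusting to match the resulting array column):

    // Before (fails): the Seq of columns is expanded into 14 separate arguments
    return functions.callUDF("generate", convertListToSeq(columnList));

    // After (should work): the columns are first wrapped into a single array column
    return functions.callUDF("generate", functions.array(convertListToSeq(columnList)));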

The following code should work (it is the actual working code, after some refactoring):

    //
    //  The UDF function implementation
    //
    static String myFunc(Seq<Object> values) {
        // Seq.iterator() returns a scala.collection.Iterator
        Iterator<Object> iterator = values.iterator();
        while (iterator.hasNext()) {
            Object object = iterator.next();
            //  Do something with your column value
        }
        return ...;
    }

    //
    //  UDF registration; `sc` here is the Spark SQL context
    //
    sc.udf().register("myFunc", (UDF1<Seq<Object>, String>) values -> myFunc(values), DataTypes.StringType);

    //
    //  Calling the UDF; note the `array` method
    //
    Dataset<Row> ds = ...;

    Seq<Column> columns = JavaConversions.asScalaBuffer(Stream
            .of(ds.schema().fields())
            .map(f -> col(f.name()))
            .collect(Collectors.toList()));
    ds = ds.withColumn("myColumn", callUDF("myFunc", array(columns)));
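
For completeness, here is a sketch of how the generate method from the question could be adapted in the same way (untested; it casts every column to string first, since functions.array expects all of its inputs to share a common type, which also lets the original Seq<String> signature be kept; the mkString return value is only illustrative):

    static UDF1<Seq<String>, String> udf = new UDF1<Seq<String>, String>() {

        @Override
        public String call(Seq<String> values) throws Exception {
            // values contains one entry per original column of the dataset
            return values.mkString(",");
        }
    };

    private static Column generate(Dataset<Row> dataset, SparkSession ss) {

        ss.udf().register("generate", udf, DataTypes.StringType);

        List<Column> columnList = new ArrayList<>();

        for (StructField structField : dataset.schema().fields()) {
            // Cast to string so that all elements of the array column share one type
            columnList.add(dataset.col(structField.name()).cast("string"));
        }

        // Wrap the columns into a single array column, so the one-argument
        // UDF "generate" is called with exactly one argument
        return functions.callUDF("generate", functions.array(convertListToSeq(columnList)));
    }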