我的 Spark 数据集中的列是动态的（列的数量和名称不固定）。我想把这些列作为一个数组一次性传给 UDF，而不是逐个传递单独的列。该如何编写一个接受列数组的 UDF？
我尝试让 UDF 接收字符串序列（Seq<String>），但失败了：
/**
 * Single-argument Spark UDF that renders its input sequence of strings with {@code toString()}.
 *
 * <p>Written as a lambda: {@code UDF1} has a single abstract method ({@code call}), so the
 * anonymous class adds nothing. A {@code null} input still throws {@code NullPointerException}.
 */
static UDF1<Seq<String>, String> udf = values -> values.toString();
/**
 * Registers the {@code "generate"} UDF and builds a column that applies it to every column
 * of {@code dataset} at once.
 *
 * <p>The UDF is a {@code UDF1}, so Spark accepts exactly one argument for it. Passing each column
 * as a separate argument to {@code callUDF} fails analysis with
 * "Invalid number of arguments for function generate. Expected: 1". Instead, all columns are
 * packed into one array column with {@code functions.array}, and the UDF receives that array
 * as a single {@code Seq}.
 *
 * @param dataset the dataset whose columns are fed to the UDF
 * @param ss the session in which the UDF is registered
 * @return a column holding the UDF result for each row
 */
private static Column generate(Dataset<Row> dataset, SparkSession ss) {
    ss.udf().register("generate", udf, DataTypes.StringType);
    List<Column> columnList = new ArrayList<>();
    for (StructField structField : dataset.schema().fields()) {
        // array() needs all elements to share one type, and the UDF is declared over
        // Seq<String>, so every column is cast to string before being packed.
        columnList.add(dataset.col(structField.name()).cast(DataTypes.StringType));
    }
    // Pack the columns into a single array argument to match the UDF1 arity.
    return functions.callUDF("generate", functions.array(convertListToSeq(columnList)));
}
/**
 * Converts a Java list of columns into a Scala {@code Seq} for Spark's Scala-facing APIs.
 *
 * <p>The elements are copied eagerly into an immutable Scala list. Converting through
 * {@code Iterator.toSeq()} (Scala 2.11/2.12) would instead produce a lazy {@code Stream} backed
 * by the live Java iterator. Later changes to {@code inputList} could then be partially observed
 * or raise {@code ConcurrentModificationException}.
 *
 * @param inputList the columns to convert
 * @return an immutable Scala sequence with the same elements in the same order
 */
private static Seq<Column> convertListToSeq(List<Column> inputList) {
    return JavaConverters.asScalaBufferConverter(inputList).asScala().toList();
}
当我调用 generate 方法时，收到以下错误消息：
Exception in thread "main" org.apache.spark.sql.AnalysisException: Invalid number of arguments for function generate. Expected: 1; Found: 14;
at org.apache.spark.sql.UDFRegistration.builder$27(UDFRegistration.scala:763)
at org.apache.spark.sql.UDFRegistration.$anonfun$register$377(UDFRegistration.scala:766)
at org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry.lookupFunction(FunctionRegistry.scala:115)
at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction(SessionCatalog.scala:1273)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.$anonfun$applyOrElse$66(Analyzer.scala:1329)
at org.apache.spark.sql.catalyst.analysis.package$.withPosition(package.scala:53)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.applyOrElse(Analyzer.scala:1329)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13$$anonfun$applyOrElse$143.applyOrElse(Analyzer.scala:1312)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:256)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:256)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$3(TreeNode.scala:261)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:326)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:324)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:261)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$transformExpressionsDown$1(QueryPlan.scala:83)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$1(QueryPlan.scala:105)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpression$1(QueryPlan.scala:105)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:116)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$3(QueryPlan.scala:121)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:233)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:58)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:51)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at scala.collection.TraversableLike.map(TraversableLike.scala:233)
at scala.collection.TraversableLike.map$(TraversableLike.scala:226)
at scala.collection.AbstractTraversable.map(Traversable.scala:104)
at org.apache.spark.sql.catalyst.plans.QueryPlan.recursiveTransform$1(QueryPlan.scala:121)
at org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$mapExpressions$4(QueryPlan.scala:126)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.plans.QueryPlan.mapExpressions(QueryPlan.scala:126)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:83)
at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:74)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13.applyOrElse(Analyzer.scala:1312)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$$anonfun$apply$13.applyOrElse(Analyzer.scala:1310)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$3(AnalysisHelper.scala:90)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$1(AnalysisHelper.scala:90)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:194)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:86)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:84)
at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:29)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$.apply(Analyzer.scala:1310)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions$.apply(Analyzer.scala:1309)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:87)
at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:122)
at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:118)
at scala.collection.immutable.List.foldLeft(List.scala:85)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:84)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:76)
at scala.collection.immutable.List.foreach(List.scala:388)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:76)
at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:127)
at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:121)
at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:106)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:201)
at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:105)
at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:57)
at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:55)
at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:47)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:79)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withPlan(Dataset.scala:3407)
at org.apache.spark.sql.Dataset.select(Dataset.scala:1336)
at org.apache.spark.sql.Dataset.withColumns(Dataset.scala:2253)
at org.apache.spark.sql.Dataset.withColumn(Dataset.scala:2220)
答案 0（得分：0）
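简而言之：UDF1 只接受一个参数，因此在把列传给 UDF 之前，应先用 `functions.array` 方法将这些列合并为一个数组列（注意：array 要求各列的类型相同或可以统一，必要时先用 cast 转换），UDF 再以一个 Seq 的形式接收这一整行的值。
下面的代码应该可以工作（它是从实际可运行的代码中稍加重构整理而来的）。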
//
// The UDF function implementation.
// Receives the values of one row's packed columns as a single Scala Seq and
// reduces them to a String (the UDF is registered with DataTypes.StringType).
//
static String myFunc(Seq<Object> values) {
// NOTE(review): Seq.iterator() returns a Scala iterator; this line only compiles if
// `Iterator` resolves to scala.collection.Iterator (not java.util.Iterator) — confirm imports.
Iterator<Object> iterator = values.iterator();
// Visit every value in the packed sequence.
while (iterator.hasNext()) {
Object object = iterator.next();
// Do something with your column value
}
// Placeholder: replace `...` with the computed result; the snippet does not compile as-is.
return ...;
}
//
// UDF registration; `sc` here is the Spark SQL context.
// A bare method name cannot be cast to a functional interface in Java, so myFunc is
// adapted with a lambda (a method reference such as `YourClass::myFunc` also works).
// The cast is kept because register() is overloaded for UDF1..UDF22 and a bare lambda
// would be ambiguous.
//
sc.udf().register("myFunc", (UDF1<Seq<Object>, String>) values -> myFunc(values), DataTypes.StringType);
//
// Calling the UDF; note the `array` method: UDF1 takes exactly one argument, so all
// columns are packed into a single array column first.
// NOTE(review): array() needs the columns to share a type Spark can unify; cast them
// (e.g. to string) first if the schema mixes incompatible types.
//
Dataset<Row> ds = ...;
Seq<Column> columns = JavaConversions.asScalaBuffer(Stream
        .of(ds.schema().fields())
        .map(f -> col(f.name()))
        .collect(Collectors.toList()));
ds = ds.withColumn("myColumn", callUDF("myFunc", array(columns)));