我有一些用 PySpark 编写的代码,正在把它转换为 Scala。转换过程一直很顺利,只是现在我在 Scala 中使用用户定义函数(UDF)时遇到了困难。
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import functions as F
# Create (or reuse) a Spark session with master 'local[*]'.
spark = SparkSession.builder.master('local[*]').getOrCreate()
# One single-column row per index value 1..10.
index_rows = [(i,) for i in range(1, 11)]
a = spark.sparkContext.parallelize(index_rows).toDF(["index"])
# Attach the literal columns a1 = 1, a2 = 2, a3 = 3.
a = a.withColumn("a1", F.lit(1)).withColumn("a2", F.lit(2)).withColumn("a3", F.lit(3))
# Pack a1..a3 into a single struct column named 'a'.
struct_fields = ['a' + str(c) for c in range(1, 4)]
a = a.select("index", F.struct(*struct_fields).alias('a'))
a.show()
def a_to_b(a):
    """Square the first three items of ``a`` and return them keyed b1..b3.

    Args:
        a: An indexable value with at least three numeric items (in this
            script, the value of struct column ``a`` handed to the UDF).

    Returns:
        dict: ``{'b1': a[0] ** 2, 'b2': a[1] ** 2, 'b3': a[2] ** 2}``.
    """
    # Key 'b<i>' holds the square of the item at position i - 1.
    return {'b' + str(i): a[i - 1] ** 2 for i in range(1, 4)}
# Return schema of the UDF: a struct with integer fields b1, b2, b3.
b_schema = StructType([StructField("b" + str(x), IntegerType()) for x in range(1, 4)])
a_to_b_udf = F.udf(lambda x: a_to_b(x), b_schema)
# Keep index and a, and add the UDF result computed from struct column a as column b.
b = a.select("index", "a", a_to_b_udf(a.a).alias("b"))
b.show()
这会产生:
+-----+-------+
|index| a|
+-----+-------+
| 1|[1,2,3]|
| 2|[1,2,3]|
| 3|[1,2,3]|
| 4|[1,2,3]|
| 5|[1,2,3]|
| 6|[1,2,3]|
| 7|[1,2,3]|
| 8|[1,2,3]|
| 9|[1,2,3]|
| 10|[1,2,3]|
+-----+-------+
和
+-----+-------+-------+
|index| a| b|
+-----+-------+-------+
| 1|[1,2,3]|[1,4,9]|
| 2|[1,2,3]|[1,4,9]|
| 3|[1,2,3]|[1,4,9]|
| 4|[1,2,3]|[1,4,9]|
| 5|[1,2,3]|[1,4,9]|
| 6|[1,2,3]|[1,4,9]|
| 7|[1,2,3]|[1,4,9]|
| 8|[1,2,3]|[1,4,9]|
| 9|[1,2,3]|[1,4,9]|
| 10|[1,2,3]|[1,4,9]|
+-----+-------+-------+
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
// Can be skipped when running in spark-shell.
val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
// Rows with index 1..10.
val indexed = spark.sparkContext.parallelize(1 to 10).toDF("index")
// Attach the literal columns a1 = 1, a2 = 2, a3 = 3.
var a = indexed.withColumn("a1", lit(1)).withColumn("a2", lit(2)).withColumn("a3", lit(3))
// Convert a1..a3 into a single struct column named "a".
val structFields = (1 to 3).map(i => col(s"a$i"))
a = a.select($"index", struct(structFields: _*).alias("a"))
a.show()
// This is where I am struggling: I have tried supplying a schema, but still get errors.
// NOTE: the lambda parameter is declared as Column. The ClassCastException quoted below
// reports that the value actually received is a GenericRowWithSchema, which cannot be
// cast to Column.
val f = udf((a: Column) => {
Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})
// Apply the UDF to struct column "a" and expose the result as column "b".
val b = a.select($"index", $"a", f($"a").alias("b"))
// Throws the error shown below.
b.show()
我可以对第一个 DataFrame 调用 show(),但在尝试对 b 调用 show() 时出现了类型转换错误。
错误是:
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to org.apache.spark.sql.Column
at $line23.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:31)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
at org.apache.spark.scheduler.Task.run(Task.scala:85)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
我已尝试像在 Python 中那样为 UDF 指定 schema,但仍然遇到同样的错误。
有谁知道如何解决这个问题?我的例子很简单,但实际上我需要在 UDF 中做很多转换,然后才返回结构体。
答案 0 :(得分:4)
我觉得自己很蠢,因为从周五下午起我就一直在为这个问题折腾。
根据《Spark Sql UDF with complex input parameter》这个问题,结构类型(struct)在传入 UDF 时会被转换为 org.apache.spark.sql.Row。
我的问题在于,我把函数的参数声明成了 Column 类型。
// Failing version: the UDF parameter is typed as Column.
val f = udf((a: Column) => {
Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})
我应该改用 Row。
// UDF over the struct column: the struct value reaches the function as a Row.
// Returns the squares of its first three integer fields as a Seq[Int].
val f = udf((a: Row) => {
  Seq.tabulate(3) { i =>
    // Read field i as Int, square it in Double arithmetic, then narrow back to Int.
    val field = a.getAs[Int](i)
    math.pow(field, 2).toInt
  }
})