PySpark to Scala: UDF with StructType, GenericRowWithSchema cannot be cast to org.apache.spark.sql.Column

Asked: 2016-11-19 16:58:15

Tags: scala apache-spark pyspark

I have some code written in PySpark that I am converting to Scala. It has been going well, except that I am now struggling with user-defined functions in Scala.

Python:

from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import functions as F

spark = SparkSession.builder.master('local[*]').getOrCreate()

a = spark.sparkContext.parallelize([(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)]).toDF(["index"]).withColumn("a1", F.lit(1)).withColumn("a2", F.lit(2)).withColumn("a3", F.lit(3))

a = a.select("index", F.struct(*('a' + str(c) for c in range(1, 4))).alias('a'))

a.show()

def a_to_b(a):
    # build a dict b1..b3 holding the square of each field of struct a
    b = {}
    for i in range(1, 4):
        b.update({'b' + str(i): a[i - 1] ** 2})
    return b

a_to_b_udf = F.udf(lambda x: a_to_b(x), StructType(list(StructField("b" + str(x), IntegerType()) for x in range(1, 4))))

b = a.select("index", "a", a_to_b_udf(a.a).alias("b"))

b.show()

This produces:

+-----+-------+
|index|      a|
+-----+-------+
|    1|[1,2,3]|
|    2|[1,2,3]|
|    3|[1,2,3]|
|    4|[1,2,3]|
|    5|[1,2,3]|
|    6|[1,2,3]|
|    7|[1,2,3]|
|    8|[1,2,3]|
|    9|[1,2,3]|
|   10|[1,2,3]|
+-----+-------+

+-----+-------+-------+
|index|      a|      b|
+-----+-------+-------+
|    1|[1,2,3]|[1,4,9]|
|    2|[1,2,3]|[1,4,9]|
|    3|[1,2,3]|[1,4,9]|
|    4|[1,2,3]|[1,4,9]|
|    5|[1,2,3]|[1,4,9]|
|    6|[1,2,3]|[1,4,9]|
|    7|[1,2,3]|[1,4,9]|
|    8|[1,2,3]|[1,4,9]|
|    9|[1,2,3]|[1,4,9]|
|   10|[1,2,3]|[1,4,9]|
+-----+-------+-------+

Scala:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// can ignore if running on spark-shell
val spark: SparkSession = SparkSession.builder()
  .master("local[*]")
  .getOrCreate()

import spark.implicits._

var a = spark.sparkContext.parallelize(1 to 10).toDF("index").withColumn("a1", lit(1)).withColumn("a2", lit(2)).withColumn("a3", lit(3))

// convert a{x} to struct column
a = a.select($"index", struct((1 to 3).map {x => col("a" + x)}.toList:_*).alias("a"))

a.show()

// this is where I am struggling, I have tried supplying a schema, but still get errors
val f = udf((a: Column) => {
  Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})

val b = a.select($"index", $"a", f($"a").alias("b"))

// throws the below error
b.show()

I can show() the first DataFrame, but I get a cast error when I try to show b.

The error is:

java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to org.apache.spark.sql.Column
  at $line23.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:31)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
  at org.apache.spark.scheduler.Task.run(Task.scala:85)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
  at java.lang.Thread.run(Thread.java:745)

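I have tried setting a schema for my UDF, as I did in Python, but I still get the same error.

Does anyone know how to solve this? My example is simple, but the real UDF needs to do a lot of transformations before it returns the struct.

1 answer:

Answer 0 (score: 4)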

I feel pretty silly, because I have been struggling with this since Friday afternoon.

From "Spark Sql UDF with complex input parameter":

Struct types are converted to org.apache.spark.sql.Row.

My problem was the Column type I declared for my function's parameter.

val f = udf((a: Column) => {
  Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})

I should have used Row instead.

val f = udf((a: Row) => {
  // a struct column reaches the UDF as a Row; read each Int field and square it
  Seq(
    Math.pow(a.getInt(0), 2).toInt,
    Math.pow(a.getInt(1), 2).toInt,
    Math.pow(a.getInt(2), 2).toInt)
})
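
Note that a UDF returning a Seq gives an array column, not the struct that the PySpark version produces. If a struct is wanted, one option is to return a case class, because Spark derives a struct schema from a case class return type. The following is only a minimal sketch under that assumption, written in the same spark-shell style as the question. The names B, aToB and aToBUdf are hypothetical and do not appear in the original post.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{lit, struct, udf}

// Hypothetical result type: its fields become the struct fields b1..b3.
case class B(b1: Int, b2: Int, b3: Int)

// Plain function over a Row, so the logic can be checked without a DataFrame.
// Assumes the struct has (at least) three Int fields, read by position.
def aToB(a: Row): B = B(
  a.getInt(0) * a.getInt(0),
  a.getInt(1) * a.getInt(1),
  a.getInt(2) * a.getInt(2))

// can ignore if running on spark-shell
val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Same shape as the question's DataFrame: an index plus a struct column "a".
val a = spark.sparkContext.parallelize(1 to 10).toDF("index")
  .select($"index",
    struct(lit(1).alias("a1"), lit(2).alias("a2"), lit(3).alias("a3")).alias("a"))

// Wrap the function as a UDF; the parameter type is Row, not Column.
val aToBUdf = udf(aToB _)

val b = a.select($"index", $"a", aToBUdf($"a").alias("b"))
b.printSchema()
b.show()
```

Keeping the transformation in an ordinary function such as aToB means that longer logic can be written and tested on a hand-built Row before it is wrapped with udf.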