我们正在尝试编写一个scala udf函数,并从pyspark中的map函数调用它。 dateframe架构非常复杂,我们要传递给此函数的列是StructType的数组。
trip_force_speeds = trip_details.groupby("vehicle_id","driver_id", "StartDtLocal", "EndDtLocal")\
.agg(collect_list(struct(col("event_start_dt_local"),
col("force"),
col("speed"),
col("sec_from_start"),
col("sec_from_end"),
col("StartDtLocal"),
col("EndDtLocal"),
col("verisk_vehicle_id"),
col("trip_duration_sec")))\
.alias("trip_details"))
在我们的map函数中,我们需要做一些计算。
def calculateVariables(rec: Row):HashMap[String,Float] = {
val trips = rec.getAs[List]("trips")
val base_variables = new HashMap[String, Float]()
val entropy_variables = new HashMap[String, Float]()
val week_day_list = List("monday", "tuesday", "wednesday", "thursday", "friday")
for (trip <- trips)
{
if (trip("start_dt_local") >= trip("StartDtLocal") && trip("start_dt_local") <= trip("EndDtLocal"))
{
base_variables("trip_summary_count") += 1
if (trip("duration_sec").toFloat >= 300 && trip("duration_sec").toFloat <= 1800) {
base_variables ("bounded_trip") += 1
base_variables("bounded_trip_duration") = trip("duration_sec") + base_variables("bounded_trip_duration")
base_variables("total_bin_1") += 30
base_variables("total_bin_2") += 30
base_variables("total_bin_3") += 60
base_variables("total_bin_5") += 60
base_variables("total_bin_6") += 30
base_variables("total_bin_7") += 30
}
if (trip("duration_sec") > 120 && trip("duration_sec") < 21600 )
{
base_variables("trip_count") += 1
}
base_variables("trip_distance") += trip("distance_km")
base_variables("trip_duration") = trip("duration_sec") + base_variables("trip_duration")
base_variables("speed_event_distance") = trip("speed_event_distance_km") + base_variables("speed_event_distance")
base_variables("speed_event_duration") = trip("speed_event_duration_sec") + base_variables("speed_event_duration")
base_variables("speed_event_distance_ratio") = trip("speed_distance_ratio") + base_variables("speed_event_distance_ratio")
base_variables("speed_event_duration_ratio") = trip("speed_duration_ratio") + base_variables("speed_event_duration_ratio")
}
}
return base_variables
}
当我们尝试编译Scala代码时,出现错误
我尝试使用Row,但出现此错误
在我的情况下,“错误:类型参数(列表)的类型与类型参数(类型T)的预期类型不符。列表的类型参数与类型T的预期参数不匹配:类型列表具有一个类型参数,但类型T没有-“
此行是一个行列表。这是架构
StructType(List(StructField(verisk_vehicle_id,StringType,true),StructField(verisk_driver_id,StringType,false),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(trips,ArrayType(StructType(List(StructField(week_start_dt_local,TimestampType,true),StructField(week_end_dt_local,TimestampType,true),StructField(start_dt_local,TimestampType,true),StructField(end_dt_local,TimestampType,true),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(verisk_vehicle_id,StringType,true),StructField(duration_sec,FloatType,true),StructField(distance_km,FloatType,true),StructField(speed_distance_ratio,FloatType,true),StructField(speed_duration_ratio,FloatType,true),StructField(speed_event_distance_km,FloatType,true),StructField(speed_event_duration_sec,FloatType,true))),true),true),StructField(trip_details,ArrayType(StructType(List(StructField(event_start_dt_local,TimestampType,true),StructField(force,FloatType,true),StructField(speed,FloatType,true),StructField(sec_from_start,FloatType,true),StructField(sec_from_end,FloatType,true),StructField(StartDtLocal,TimestampType,true),StructField(EndDtLocal,TimestampType,true),StructField(verisk_vehicle_id,StringType,true),StructField(trip_duration_sec,FloatType,true))),true),true)))
我们定义的试图覆盖spark structtype的函数签名的方式存在问题,但这对我不起作用。
我来自python背景,在python作业中遇到一些性能问题,这就是为什么我决定在Scala中编写此map函数的原因。
答案 0 :(得分:2)
您必须使用Row类型而不是udf中的StructType。 StructType表示架构本身而不是数据。您可以使用Scala中的一个小示例:
object test{
import org.apache.spark.sql.functions.{udf, collect_list, struct}
val hash = HashMap[String, Float]("start_dt_local" -> 0)
// This simple type to store you results
val sampleDataset = Seq(Row(Instant.now().toEpochMilli, Instant.now().toEpochMilli))
implicit val spark: SparkSession =
SparkSession
.builder()
.appName("Test")
.master("local[*]")
.getOrCreate()
def calculateVariablesUdf = udf { trip: Row =>
if(trip.getAs[Long]("start_dt_local") >= trip.getAs[Long]("StartDtLocal")) {
// crate a new instance with your results
hash("start_dt_local") + 1
} else {
hash("start_dt_local") + 0
}
}
def main(args: Array[String]) : Unit = {
Logger.getLogger("org").setLevel(Level.OFF)
Logger.getLogger("akka").setLevel(Level.OFF)
val rdd = spark.sparkContext.parallelize(sampleDataset)
val df = spark.createDataFrame(rdd, StructType(List(StructField("start_dt_local", LongType, false), StructField("StartDtLocal", LongType, false))))
df.agg(collect_list(calculateVariablesUdf(struct(col("start_dt_local"), col("StartDtLocal")))).as("result")).show(false)
}
}
编辑。为了更好的理解:
当您考虑架构描述时,您错了:StructType(List(StructField))作为字段的类型。您的DataFrame中没有列表类型。
如果将calculateVariables视为udf,则不需要for循环。我的意思是:
def calculateVariables = udf { trip: Row =>
trip("start_dt_local").getAs[Long]
// your logic ....
}
在我的示例中,您可以直接在udf中返回更新的哈希值