I want to convert a string that represents a List<List<Long,Float,Float,Integer,Integer>> inside an array.
An example of such a string is [[337, -115.0, -17.5, 6225, 189],[85075, -112.0, -12.5, 6225, 359]].
To achieve this, I am using a UDF function with the following structure:
def convertToListOfListComplex(ListOfList: String, regex: String): Array[StructType] = {
  val notBracket = ListOfList.dropRight(1).drop(1)
  val SplitString = notBracket.split("]").map(x => if (x.startsWith("[")) x.drop(1) else x.drop(2))
  SplitString(0).replaceAll("\\s", "")
  val result = SplitString map { s =>
    val split = s.replaceAll("\\s", "").trim.split(",")
    case class Row(a: Long, b: Float, c: Float, d: Int, e: Int)
    val element = Row(split(0).toLong, split(1).toFloat, split(2).toFloat, split(3).toInt, split(4).toInt)
    val schema = `valid code to transform to case class to StructType`
  }
  result
}
I am using Spark 2.2. I have tried different solutions, but I keep running into problems building the array of StructTypes: I either get compilation errors or failures at execution time. Any suggestions?
Answer 0 (score: 2)
For testing purposes, I created a test dataframe with the string mentioned in the question:

import spark.implicits._

val df = Seq(
  Tuple1("[[337, -115.0, -17.5, 6225, 189],[85075, -112.0, -12.5, 6225, 359]]")
).toDF("col")

Calling df.show(false) and df.printSchema on it gives
+-------------------------------------------------------------------+
|col |
+-------------------------------------------------------------------+
|[[337, -115.0, -17.5, 6225, 189],[85075, -112.0, -12.5, 6225, 359]]|
+-------------------------------------------------------------------+
root
|-- col: string (nullable = true)
The udf function should look like this:
import org.apache.spark.sql.functions._

def convertToListOfListComplex = udf((ListOfList: String) => {
  ListOfList.split("],\\[")                              // split between the inner lists
    .map(x => x.replaceAll("[\\]\\[]", "").split(","))   // strip remaining brackets, split the fields
    .map(splitted => rowTest(splitted(0).trim.toLong, splitted(1).trim.toFloat, splitted(2).trim.toFloat, splitted(3).trim.toInt, splitted(4).trim.toInt))
})
where rowTest is a case class defined outside the scope of the udf as
case class rowTest(a: Long, b: Float, c: Float, d: Int, e: Int)
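As a side note on the `valid code to transform to case class to StructType` placeholder in the question: if what you need is the StructType itself, a minimal sketch (rowTestSchema is just an illustrative name) is to derive it from the case class via its encoder:

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.types.StructType

// StructType that Spark derives for the rowTest case class
val rowTestSchema: StructType = Encoders.product[rowTest].schema

The udf itself, however, should return the data (an array of rowTest instances, which Spark encodes as an array of structs), not StructType values; a StructType only describes a schema, which is why returning Array[StructType] leads to the errors mentioned in the question.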
and the udf function is invoked as

df.withColumn("converted", convertToListOfListComplex(col("col")))

which should display your output as
+-------------------------------------------------------------------+--------------------------------------------------------------------+
|col |converted |
+-------------------------------------------------------------------+--------------------------------------------------------------------+
|[[337, -115.0, -17.5, 6225, 189],[85075, -112.0, -12.5, 6225, 359]]|[[337, -115.0, -17.5, 6225, 189], [85075, -112.0, -12.5, 6225, 359]]|
+-------------------------------------------------------------------+--------------------------------------------------------------------+
root
|-- col: string (nullable = true)
|-- converted: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- a: long (nullable = false)
| | |-- b: float (nullable = false)
| | |-- c: float (nullable = false)
| | |-- d: integer (nullable = false)
| | |-- e: integer (nullable = false)
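Once converted, the struct fields can be queried directly. As a hypothetical follow-up (column and field names taken from the example above), the array can be flattened with explode:

val converted = df.withColumn("converted", convertToListOfListComplex(col("col")))

// one output row per inner struct, then select the individual fields
converted
  .select(explode(col("converted")).as("row"))
  .select(col("row.a"), col("row.b"), col("row.c"), col("row.d"), col("row.e"))
  .show(false)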
To be on the safe side, you can use Try/getOrElse inside the udf function as
import org.apache.spark.sql.functions._
import scala.util.Try

def convertToListOfListComplex = udf((ListOfList: String) => {
  ListOfList.split("],\\[")
    .map(x => x.replaceAll("[\\]\\[]", "").split(","))
    .map(splitted => rowTest(
      Try(splitted(0).trim.toLong).getOrElse(0L),
      Try(splitted(1).trim.toFloat).getOrElse(0F),
      Try(splitted(2).trim.toFloat).getOrElse(0F),
      Try(splitted(3).trim.toInt).getOrElse(0),
      Try(splitted(4).trim.toInt).getOrElse(0)))
})
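With this variant, a field that fails to parse falls back to its getOrElse default instead of failing the whole job. A quick illustration (the malformed input string here is made up for this test):

// "not_a_number" cannot be parsed to Long, so field a becomes 0
Seq(Tuple1("[[not_a_number, -115.0, -17.5, 6225, 189]]")).toDF("col")
  .withColumn("converted", convertToListOfListComplex(col("col")))
  .show(false)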
I hope the answer is helpful.