在pyspark中使用数组类型列

时间:2019-07-19 13:47:43

标签: python-3.x apache-spark pyspark

我有一个pyspark数据框,其中包含多个数组类型的列,其值作为Long类型。这里以几列为例。我想将此数据帧另存为csv文件。在将列强制转换为“字符串”时,因此对于列中的值,我得到的是“ org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@d5d9fa73”。有人可以帮我吗?

             |-- VoltageMin: array (nullable = true)
                |-- element: long (containsNull = true)
             |-- Temp: array (nullable = true)
                |-- element: long (containsNull = true)
             |-- Speed: array (nullable = true)
                |-- element: long (containsNull = true)
             |-- PowerConsumption: array (nullable = true)
                |-- element: long (containsNull = true)
             |-- VoltageMax: array (nullable = true)
                |-- element: long (containsNull = true)

2 个答案:

答案 0 :(得分:0)

您可以使用explode()功能。通常用于将具有多个值的数组分解为具有该数组中单个值的多行(复制所有其他列)。

就您而言,您可以df = df.withColumn("Temp", explode("Temp")) 并对每一列重复一次,就可以解决问题。

答案 1 :(得分:0)

由于CSV是一种基于文本的格式,其本身并不支持复杂类型,因此明智的做法是使用array_join将数组存储为非逗号分隔的字符串。对于更复杂的内容,将其存储为JSON字符串可能更合适,但是对于简单的数组,以下内容(使用|作为数组的分隔符)就足够了:

scala> val input = spark.range(100)
input: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> val tst = input.map(x => (x / 6, x / 5, x / 4, x / 3, x / 2, x))
tst: org.apache.spark.sql.Dataset[(Long, Long, Long, Long, Long, Long)] = [_1: bigint, _2: bigint ... 4 more fields]

scala> val with_arrays = tst.groupBy('_1.as("id")).agg(collect_list('_2).as("VoltageMin"), collect_list('_3).as("Temp"), collect_list('_4).as("Speed"), collect_list('_5).as("PowerConsumption"), collect_list('_6).as("VoltageMax"))
with_arrays: org.apache.spark.sql.DataFrame = [id: bigint, VoltageMin: array<bigint> ... 4 more fields]

scala> val arrayCols = with_arrays.schema.map(_.name).drop(1)
arrayCols: Seq[String] = List(VoltageMin, Temp, Speed, PowerConsumption, VoltageMax)

scala> val flat = arrayCols.foldLeft(with_arrays){(df, field) => df.withColumn(field, array_join(df(field), "|"))}
flat: org.apache.spark.sql.DataFrame = [id: bigint, VoltageMin: string ... 4 more fields]

scala> flat.show(5, false)
+---+-----------------+-----------------+-----------------+-----------------+-----------------+
|id |VoltageMin       |Temp             |Speed            |PowerConsumption |VoltageMax       |
+---+-----------------+-----------------+-----------------+-----------------+-----------------+
|0  |0|0|0|0|0|1      |0|0|0|0|1|1      |0|0|0|1|1|1      |0|0|1|1|2|2      |0|1|2|3|4|5      |
|7  |8|8|8|9|9|9      |10|10|11|11|11|11|14|14|14|15|15|15|21|21|22|22|23|23|42|43|44|45|46|47|
|6  |7|7|7|7|8|8      |9|9|9|9|10|10    |12|12|12|13|13|13|18|18|19|19|20|20|36|37|38|39|40|41|
|9  |10|11|11|11|11|11|13|13|14|14|14|14|18|18|18|19|19|19|27|27|28|28|29|29|54|55|56|57|58|59|
|5  |6|6|6|6|6|7      |7|7|8|8|8|8      |10|10|10|11|11|11|15|15|16|16|17|17|30|31|32|33|34|35|
+---+-----------------+-----------------+-----------------+-----------------+-----------------+
only showing top 5 rows


scala> flat.repartition(5).write.mode("overwrite").csv("csv_test")

scala> :quit
➜  ~ head csv_test/part-00000-ad853063-19d9-47cc-bc9c-b3cfa1697638-c000.csv 
6,7|7|7|7|8|8,9|9|9|9|10|10,12|12|12|13|13|13,18|18|19|19|20|20,36|37|38|39|40|41
9,10|11|11|11|11|11,13|13|14|14|14|14,18|18|18|19|19|19,27|27|28|28|29|29,54|55|56|57|58|59
12,14|14|14|15|15|15,18|18|18|18|19|19,24|24|24|25|25|25,36|36|37|37|38|38,72|73|74|75|76|77
2,2|2|2|3|3|3,3|3|3|3|4|4,4|4|4|5|5|5,6|6|7|7|8|8,12|13|14|15|16|17
14,16|17|17|17|17|17,21|21|21|21|22|22,28|28|28|29|29|29,42|42|43|43|44|44,84|85|86|87|88|89
15,18|18|18|18|18|19,22|22|23|23|23|23,30|30|30|31|31|31,45|45|46|46|47|47,90|91|92|93|94|95

再次读入数据时,可以使用split函数再次将以竖线分隔的字符串转换为数组。