假设我有一个DataFrame:
val testDf = sc.parallelize(Seq(
(1,2,"x", Array(1,2,3,4)))).toDF("one", "two", "X", "Array")
+---+---+---+------------+
|one|two| X| Array|
+---+---+---+------------+
| 1| 2| x|[1, 2, 3, 4]|
+---+---+---+------------+
我想复制单个元素,比方说4次,以实现单行DataFrame,每个字段为四个元素的数组。所需的输出是:
+------------+------------+------------+------------+
| one| two| X| Array|
+------------+------------+------------+------------+
|[1, 1, 1, 1]|[2, 2, 2, 2]|[x, x, x, x]|[1, 2, 3, 4]|
+------------+------------+------------+------------+
答案 0 :(得分:1)
嗯,这是我的解决方案:
首先声明要复制的列:
val columnsToReplicate = List("one", "two", "X")
然后定义复制因子和udf来执行它:
val replicationFactor = 4
val replicate = (s:String) => {
for {
i <- 1 to replicationFactor
} yield s
}
val replicateudf = functions.udf(replicate)
当columname属于您想要的列名列表时,只需在DataFrame
上执行foldLeft:
testDf.columns.foldLeft(testDf)((acc, colname) => if (columnsToReplicate.contains(colname)) acc.withColumn(colname, replicateudf(acc.col(colname))) else acc)
输出:
+------------+------------+------------+------------+
| one| two| X| Array|
+------------+------------+------------+------------+
|[1, 1, 1, 1]|[2, 2, 2, 2]|[x, x, x, x]|[1, 2, 3, 4]|
+------------+------------+------------+------------+
注意:您需要导入此类:
import org.apache.spark.sql.functions
修改强>
注释中建议的变量replicationFactor:
val mapColumnsToReplicate = Map("one"->4, "two"->5, "X"->6)
val replicateudf2 = functions.udf ((s: String, replicationFactor: Int) =>
for {
i <- 1 to replicationFactor
} yield s
)
testDf.columns.foldLeft(testDf)((acc, colname) => if (mapColumnsToReplicate.keys.toList.contains(colname)) acc.withColumn(colname, replicateudf2($"$colname", functions.lit(mapColumnsToReplicate(colname))))` else acc)
使用上述值输出:
+------------+---------------+------------------+------------+
| one| two| X| Array|
+------------+---------------+------------------+------------+
|[1, 1, 1, 1]|[2, 2, 2, 2, 2]|[x, x, x, x, x, x]|[1, 2, 3, 4]|
+------------+---------------+------------------+------------+
答案 1 :(得分:1)
您可以使用内置array
功能复制您选择的n
时间列。
以下是PoC代码。
import org.apache.spark.sql.functions._
val replicate = (n: Int, colName: String) => array((1 to n).map(s => col(colName)):_*)
val replicatedCol = Seq("one", "two", "X").map(s => replicate(4, s).as(s))
val cols = col("Array") +: replicatedCol
val testDf = sc.parallelize(Seq(
(1,2,"x", Array(1,2,3,4)))).toDF("one", "two", "X", "Array").select(cols:_*)
testDf.show(false)
+------------+------------+------------+------------+
|Array |one |two |X |
+------------+------------+------------+------------+
|[1, 2, 3, 4]|[1, 1, 1, 1]|[2, 2, 2, 2]|[x, x, x, x]|
+------------+------------+------------+------------+
在这种情况下,您希望每列有不同的n
val testDf = sc.parallelize(Seq(
(1,2,"x", Array(1,2,3,4)))).toDF("one", "two", "X", "Array").select(replicate(2, "one").as("one"), replicate(3, "X").as("X"), replicate(4, "two").as("two"), $"Array")
testDf.show(false)
+------+---------+------------+------------+
|one |X |two |Array |
+------+---------+------------+------------+
|[1, 1]|[x, x, x]|[2, 2, 2, 2]|[1, 2, 3, 4]|
+------+---------+------------+------------+
答案 2 :(得分:0)
您可以使用explode
和groupBy
/ collect_list
:
val testDf = sc.parallelize(
Seq((1, 2, "x", Array(1, 2, 3, 4)),
(3, 4, "y", Array(1, 2, 3)),
(5,6, "z", Array(1)))
).toDF("one", "two", "X", "Array")
testDf
.withColumn("id",monotonically_increasing_id())
.withColumn("tmp", explode($"Array"))
.groupBy($"id")
.agg(
collect_list($"one").as("cl_one"),
collect_list($"two").as("cl_two"),
collect_list($"X").as("cl_X"),
first($"Array").as("Array")
)
.select(
$"cl_one".as("one"),
$"cl_two".as("two"),
$"cl_X".as("X"),
$"Array"
)
.show()
+------------+------------+------------+------------+
| one| two| X| Array|
+------------+------------+------------+------------+
| [5]| [6]| [z]| [1]|
|[1, 1, 1, 1]|[2, 2, 2, 2]|[x, x, x, x]|[1, 2, 3, 4]|
| [3, 3, 3]| [4, 4, 4]| [y, y, y]| [1, 2, 3]|
+------------+------------+------------+------------+
此解决方案的优势在于它不依赖于常量数组大小