Sampling values from a list column in a Spark DataFrame

Time: 2017-04-08 20:50:27

Tags: scala list spark-dataframe

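I have a Spark Scala DataFrame, shown as df1 below. For each row, I want to sample with replacement from the scores column (a list), taking as many values as the counts column of that row specifies.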

val df1 = sc.parallelize(Seq(("a1",2,List(20,10)),("a2",1,List(30,10)),
("a3",3,List(10)),("a4",2,List(10,20,40)))).toDF("colA","counts","scores")

df1.show()
+----+------+------------+
|colA|counts|      scores|
+----+------+------------+
|  a1|     2|    [20, 10]|
|  a2|     1|    [30, 10]|
|  a3|     3|        [10]|
|  a4|     2|[10, 20, 40]|
+----+------+------------+

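The expected output is shown in df2: for row 1, sample 2 values from the list [20, 10]; for row 2, sample 1 value from the list [30, 10]; for row 3, sample 3 values from the list [10], with repetition; and so on.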

df2.show() //expected output
+----+------+------------+-------------+
|colA|counts|      scores|sampledScores|
+----+------+------------+-------------+
|  a1|     2|    [20, 10]|     [20, 10]|
|  a2|     1|    [30, 10]|         [30]|
|  a3|     3|        [10]| [10, 10, 10]|
|  a4|     2|[10, 20, 40]|     [10, 40]|
+----+------+------------+-------------+

I wrote a UDF, takeSample, and applied it to df1, but it does not work as expected.

val takeSample = udf((a:Array[Int], count1:Int) => {Array.fill(count1)(
a(new Random(System.currentTimeMillis).nextInt(a.size)))}
)

val df2 = df1.withColumn("SampledScores", takeSample(df1("Scores"),df1("counts")))

I get the following runtime error when I run df2.show() (df2.printSchema() succeeds):

df2.printSchema()
root
 |-- colA: string (nullable = true)
 |-- counts: integer (nullable = true)
 |-- scores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- SampledScores: array (nullable = true)
 |    |-- element: integer (containsNull = false)

df2.show()
org.apache.spark.SparkException: Failed to execute user defined   
function($anonfun$1: (array<int>, int) => array<int>)
Caused by: java.lang.ClassCastException:  
scala.collection.mutable.WrappedArray$ofRef cannot be cast to [I
at $anonfun$1.apply(<console>:47)

Any solution would be much appreciated.

1 answer:

Answer 0 (score: 1)

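Changing the UDF parameter type from Array[Int] to Seq[Int] resolves the problem. Spark passes an array column to a Scala UDF as a WrappedArray, which is a Seq and cannot be cast to Array[Int]: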

import org.apache.spark.sql.functions.udf
import scala.util.Random

val takeSample = udf((a:Seq[Int], count1:Int) => {Array.fill(count1)(
a(new Random(System.currentTimeMillis).nextInt(a.size)))}
)

val df2 = df1.withColumn("SampledScores", takeSample(df1("Scores"),df1("counts")))
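
The ClassCastException from the question can be reproduced without Spark. The sketch below is an illustration only: it uses a plain Seq as a stand-in for the WrappedArray that Spark hands to the UDF (the variable names are hypothetical), and shows that such a value can be treated as Seq[Int] but cannot be cast to Array[Int].

```scala
import scala.util.Try

// Stand-in (assumption) for the value Spark passes to the UDF:
// a Seq-backed collection whose static type is unknown to the caller.
val columnValue: Any = Seq(20, 10)

// Casting to a primitive array fails at runtime with ClassCastException.
val asArray = Try(columnValue.asInstanceOf[Array[Int]])

// Treating the same value as a Seq succeeds.
val asSeq = Try(columnValue.asInstanceOf[Seq[Int]])

println(s"cast to Array[Int] failed: ${asArray.isFailure}")
println(s"cast to Seq[Int] succeeded: ${asSeq.isSuccess}")
```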

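With the Seq[Int] version of the UDF, df2.show runs without the exception. The sampled values are random, so they vary between runs: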

df2.printSchema()
root
 |-- colA: string (nullable = true)
 |-- counts: integer (nullable = true)
 |-- scores: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- SampledScores: array (nullable = true)
 |    |-- element: integer (containsNull = false)

df2.show
+----+------+------------+-------------+
|colA|counts|      scores|SampledScores|
+----+------+------------+-------------+
|  a1|     2|    [20, 10]|     [20, 20]|
|  a2|     1|    [30, 10]|         [30]|
|  a3|     3|        [10]| [10, 10, 10]|
|  a4|     2|[10, 20, 40]|     [20, 20]|
+----+------+------------+-------------+
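
In the output above, rows a1 and a4 each contain the same value twice. One plausible cause is that the UDF builds a new Random seeded with System.currentTimeMillis for every draw, so draws made within the same millisecond start from the same seed and pick the same element. The following is a minimal sketch of an alternative, assuming one generator per call is acceptable. sampleWithReplacement is a hypothetical helper and is not part of the original answer.

```scala
import scala.util.Random

// Hypothetical helper (not from the original answer): draw `n` values
// from `values` with replacement, reusing one generator for all draws.
def sampleWithReplacement(values: Seq[Int], n: Int, rng: Random): Seq[Int] =
  if (values.isEmpty) Seq.empty[Int] // nothing to sample from
  else Seq.fill(n)(values(rng.nextInt(values.size)))

// Plain-Scala usage; the fixed seed is only there to make the run repeatable.
val rng = new Random(42L)
val sampled = sampleWithReplacement(Seq(10, 20, 40), 2, rng)
println(sampled)
```

In Spark, the helper could be wrapped along the lines of udf((a: Seq[Int], n: Int) => sampleWithReplacement(a, n, new Random())), which creates one generator per row instead of one per draw.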