在数据框列中填充嵌套数组

时间:2020-06-11 13:53:54

标签: scala apache-spark apache-spark-sql

我有一个具有以下结构的数据框:

a

数组列必须保存两个元素(数组),从该元素创建的元素不丢失。 例如,我有这个:

x

期望的结果是:

 |-- col0: double (nullable = true)
 |-- arr: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: double (containsNull = false)

因此,当数组具有2个元素时,无需执行任何操作。但是,如果有一个元素,则需要创建第二个元素,而该元素的值丢失(如果有一个元素的值为0.0,那么我需要创建一个值为[1.0,0.0]的元素,并且如果有一个元素的值为0.0 ,我需要[0.0,0.0])。

我尝试了以下方法,但是没有用:

|0.0 |[[0.0, 182.0], [1.0, 14.0]]|
|0.0 |[[1.0, 60.0]]              |
|1.0 |[[0.0, 3.0], [1.0, 48.0]]  |
|2.0 |[[1.0, 6.0], [0.0, 111.0]] |
|0.0 |[[1.0, 4.0], [0.0, 120.0]] |
|2.0 |[[0.0, 21.0]]              |
|0.0 |[[0.0, 3.0], [1.0, 13.0]]  |

错误是:

|0.0 |[[0.0, 182.0], [1.0, 14.0]]|
|0.0 |[[0.0, 0.0], [1.0, 60.0]]  |
|1.0 |[[0.0, 3.0], [1.0, 48.0]]  |
|2.0 |[[0.0, 111.0], [1.0, 6.0]] |
|0.0 |[[0.0, 120.0], [1.0, 4.0]] |
|2.0 |[[0.0, 21.0], [1.0, 0.0]]  |
|0.0 |[[0.0, 3.0], [1.0, 13.0]]  |

1 个答案:

答案 0 :(得分:1)

不要将UDF的输入定义为Array,而是将其定义为Seq,您应该会很好:

val headValue = udf((arr: Seq[Seq[Double]], maxValue: Double, minValue: Double) => {
  val flatArr = arr.flatMap(_.headOption)
  val nArr = arr
  if (flatArr.length == 1){
    if (flatArr.head == maxValue){
      nArr :+  Seq(minValue, 0.0)
    } else {
      nArr :+  Seq(maxValue, 0.0)
    }
  } else {
    nArr
  }
})