在Spark中使用UDF时出现NullPointerException

时间:2017-03-30 11:05:56

标签: scala apache-spark dataframe user-defined-functions udf

我在Spark中有一个DataFrame,例如:

var df = List(
  (1,"{NUM.0002}*{NUM.0003}"),
  (2,"{NUM.0004}+{NUM.0003}"),
  (3,"END(6)"),
  (4,"END(4)")
).toDF("CODE", "VALUE")

+----+---------------------+
|CODE|                VALUE|
+----+---------------------+
|   1|{NUM.0002}*{NUM.0003}|
|   2|{NUM.0004}+{NUM.0003}|
|   3|               END(6)|
|   4|               END(4)|
+----+---------------------+

我的任务是遍历VALUE列并执行以下操作:检查是否存在{NUM.XXXX}等子字符串,获取XXXX号码,获取$" CODE" === XXXX,并将{NUM.XXX}子字符串替换为该行中的VALUE字符串。

我希望数据框最终看起来像这样:

+----+--------------------+
|CODE|               VALUE|
+----+--------------------+
|   1|END(4)+END(6)*END(6)|
|   2|       END(4)+END(6)|
|   3|              END(6)|
|   4|              END(4)|
+----+--------------------+

这是我提出的最佳选择:

val process = udf((ln: String) => {
  var newln = ln
  while(newln contains "{NUM."){
    var num = newln.slice(newln.indexOf("{")+5, newln.indexOf("}")).toInt 
    var new_value = df.where($"CODE" === num).head.getAs[String](1)
    newln = newln.replace(newln.slice(newln.indexOf("{"),newln.indexOf("}")+1), new_value)
  }
  newln
})

var df2 = df.withColumn("VALUE", when('VALUE contains "{NUM.",process('VALUE)).otherwise('VALUE))

不幸的是,当我尝试过滤/选择/保存df2时,我收到NullPointerException,当我只显示df2时没有错误。我相信当我访问UDF中的DataFrame df时会出现错误,但我需要在每次迭代时访问它,因此我无法将其作为输入传递。另外,我尝试在UDF中保存df的副本,但我不知道该怎么做。我能在这做什么?

非常欢迎任何改进算法的建议!谢谢!

1 个答案:

答案 0 :(得分:1)

我写了一些有效但不太优化的东西。我实际上在初始DataFrame上进行递归连接以将END替换为NUM。这是代码:

    case class Data(code: Long, value: String)

    def main(args: Array[String]): Unit = {
        val sparkSession: SparkSession = SparkSession.builder().master("local").getOrCreate()

        val data = Seq(
            Data(1,"{NUM.0002}*{NUM.0003}"),
            Data(2,"{NUM.0004}+{NUM.0003}"),
            Data(3,"END(6)"),
            Data(4,"END(4)"),
            Data(5,"{NUM.0002}")
        )

        val initialDF = sparkSession.createDataFrame(data)
        val endDF = initialDF.filter(!(col("value") contains "{NUM"))
        val numDF = initialDF.filter(col("value") contains "{NUM")

        val resultDF = endDF.union(replaceNumByEnd(initialDF, numDF))
        resultDF.show(false)
    }


    val parseNumUdf = udf((value: String) => {
        if (value.contains("{NUM")) {
            val regex = """.*?\{NUM\.(\d+)\}.*""".r
            value match {
                case regex(code) => code.toLong
            }
        } else {
            -1L
        }
    })

    val replaceUdf = udf((value: String, replacement: String) => {
        val regex = """\{NUM\.(\d+)\}""".r
        regex.replaceFirstIn(value, replacement)
    })

    def replaceNumByEnd(initialDF: DataFrame, currentDF: DataFrame): DataFrame = {
        if (currentDF.count() == 0) {
            currentDF
        } else {
            val numDFWithCode = currentDF
                .withColumn("num_code", parseNumUdf(col("value")))
                .withColumnRenamed("code", "code_original")
                .withColumnRenamed("value", "value_original")

            val joinedDF = numDFWithCode.join(initialDF, numDFWithCode("num_code") === initialDF("code"))

            val replacedDF = joinedDF.withColumn("value_replaced", replaceUdf(col("value_original"), col("value")))

            val nextDF = replacedDF.select(col("code_original").as("code"), col("value_replaced").as("value"))

            val endDF = nextDF.filter(!(col("value") contains "{NUM"))
            val numDF = nextDF.filter(col("value") contains "{NUM")

            endDF.union(replaceNumByEnd(initialDF, numDF))
        }
    }

如果您需要更多解释,请不要犹豫。