How do I write multiple WHEN conditions for a Spark DataFrame?

Asked: 2018-01-30 09:38:23

Tags: scala apache-spark dataframe apache-spark-sql

I have to join two DataFrames and then select all the columns based on some conditions. Here is an example:

val sqlContext = new org.apache.spark.sql.SQLContext(sc)

import sqlContext.implicits._
import org.apache.spark.{SparkConf, SparkContext}
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._ // when, coalesce, rank, concat_ws, regexp_replace, col, ...
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.functions.regexp_extract

val get_cus_val = sqlContext.udf.register("get_cus_val", (filePath: String) => filePath.split("\\.")(3))

val rdd = sc.textFile("s3://trfsmallfffile/FinancialLineItem/MAIN")
val header = rdd.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
val schema = StructType(header.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
val data = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema)

val schemaHeader = StructType(header.map(cols => StructField(cols.replace(".", "."), StringType)).toSeq)
val dataHeader = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schemaHeader)

val df1resultFinal=data.withColumn("DataPartition", get_cus_val(input_file_name))

val rdd1 = sc.textFile("s3://trfsmallfffile/FinancialLineItem/INCR")
val header1 = rdd1.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
val schema1 = StructType(header1.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
val data1 = sqlContext.createDataFrame(rdd1.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema1)


import org.apache.spark.sql.expressions._
val windowSpec = Window.partitionBy("LineItem_organizationId", "LineItem_lineItemId").orderBy($"TimeStamp".cast(LongType).desc) 
val latestForEachKey = data1.withColumn("rank", rank().over(windowSpec)).filter($"rank" === 1).drop("rank", "TimeStamp")


val dfMainOutput = df1resultFinal.join(latestForEachKey, Seq("LineItem_organizationId", "LineItem_lineItemId"), "outer")
      .select($"LineItem_organizationId", $"LineItem_lineItemId",
        when($"DataPartition_1".isNotNull, $"DataPartition_1").otherwise($"DataPartition").as("DataPartition"),
        when($"StatementTypeCode_1".isNotNull, $"StatementTypeCode_1").otherwise($"StatementTypeCode").as("StatementTypeCode"),
        when($"LineItemName_1".isNotNull, $"LineItemName_1").otherwise($"LineItemName").as("LineItemName"),
        when($"LocalLanguageLabel_1".isNotNull, $"LocalLanguageLabel_1").otherwise($"LocalLanguageLabel").as("LocalLanguageLabel"),
        when($"FinancialConceptLocal_1".isNotNull, $"FinancialConceptLocal_1").otherwise($"FinancialConceptLocal").as("FinancialConceptLocal"),
        when($"FinancialConceptGlobal_1".isNotNull, $"FinancialConceptGlobal_1").otherwise($"FinancialConceptGlobal").as("FinancialConceptGlobal"),
        when($"IsDimensional_1".isNotNull, $"IsDimensional_1").otherwise($"IsDimensional").as("IsDimensional"),
        when($"InstrumentId_1".isNotNull, $"InstrumentId_1").otherwise($"InstrumentId").as("InstrumentId"),
        when($"LineItemSequence_1".isNotNull, $"LineItemSequence_1").otherwise($"LineItemSequence").as("LineItemSequence"),
        when($"PhysicalMeasureId_1".isNotNull, $"PhysicalMeasureId_1").otherwise($"PhysicalMeasureId").as("PhysicalMeasureId"),
        when($"FinancialConceptCodeGlobalSecondary_1".isNotNull, $"FinancialConceptCodeGlobalSecondary_1").otherwise($"FinancialConceptCodeGlobalSecondary").as("FinancialConceptCodeGlobalSecondary"),
        when($"IsRangeAllowed_1".isNotNull, $"IsRangeAllowed_1").otherwise($"IsRangeAllowed").as("IsRangeAllowed"),
        when($"IsSegmentedByOrigin_1".isNotNull, $"IsSegmentedByOrigin_1").otherwise($"IsSegmentedByOrigin".cast(DataTypes.StringType)).as("IsSegmentedByOrigin"),
        when($"SegmentGroupDescription_1".isNotNull, $"SegmentGroupDescription_1").otherwise($"SegmentGroupDescription").as("SegmentGroupDescription"),
        when($"SegmentChildDescription_1".isNotNull, $"SegmentChildDescription_1").otherwise($"SegmentChildDescription").as("SegmentChildDescription"),
        when($"SegmentChildLocalLanguageLabel_1".isNotNull, $"SegmentChildLocalLanguageLabel_1").otherwise($"SegmentChildLocalLanguageLabel").as("SegmentChildLocalLanguageLabel"),
        when($"LocalLanguageLabel_languageId_1".isNotNull, $"LocalLanguageLabel_languageId_1").otherwise($"LocalLanguageLabel_languageId").as("LocalLanguageLabel_languageId"),
        when($"LineItemName_languageId_1".isNotNull, $"LineItemName_languageId_1").otherwise($"LineItemName_languageId").as("LineItemName_languageId"),
        when($"SegmentChildDescription_languageId_1".isNotNull, $"SegmentChildDescription_languageId_1").otherwise($"SegmentChildDescription_languageId").as("SegmentChildDescription_languageId"),
        when($"SegmentChildLocalLanguageLabel_languageId_1".isNotNull, $"SegmentChildLocalLanguageLabel_languageId_1").otherwise($"SegmentChildLocalLanguageLabel_languageId").as("SegmentChildLocalLanguageLabel_languageId"),
        when($"SegmentGroupDescription_languageId_1".isNotNull, $"SegmentGroupDescription_languageId_1").otherwise($"SegmentGroupDescription_languageId").as("SegmentGroupDescription_languageId"),
        when($"SegmentMultipleFundbDescription_1".isNotNull, $"SegmentMultipleFundbDescription_1").otherwise($"SegmentMultipleFundbDescription").as("SegmentMultipleFundbDescription"),
        when($"SegmentMultipleFundbDescription_languageId_1".isNotNull, $"SegmentMultipleFundbDescription_languageId_1").otherwise($"SegmentMultipleFundbDescription_languageId").as("SegmentMultipleFundbDescription_languageId"),
        when($"IsCredit_1".isNotNull, $"IsCredit_1").otherwise($"IsCredit").as("IsCredit"),
        when($"FinancialConceptLocalId_1".isNotNull, $"FinancialConceptLocalId_1").otherwise($"FinancialConceptLocalId").as("FinancialConceptLocalId"),
        when($"FinancialConceptGlobalId_1".isNotNull, $"FinancialConceptGlobalId_1").otherwise($"FinancialConceptGlobalId").as("FinancialConceptGlobalId"),
        when($"FinancialConceptCodeGlobalSecondaryId_1".isNotNull, $"FinancialConceptCodeGlobalSecondaryId_1").otherwise($"FinancialConceptCodeGlobalSecondaryId").as("FinancialConceptCodeGlobalSecondaryId"),
        when($"FFAction_1".isNotNull, $"FFAction_1").otherwise($"FFAction|!|").as("FFAction|!|"))
        .filter(!$"FFAction|!|".contains("D|!|"))

val dfMainOutputFinal = dfMainOutput.na.fill("").select($"DataPartition",$"StatementTypeCode",concat_ws("|^|", dfMainOutput.schema.fieldNames.filter(_ != "DataPartition").map(c => col(c)): _*).as("concatenated"))

val headerColumn = dataHeader.columns.toSeq

val header = headerColumn.mkString("", "|^|", "|!|").dropRight(3)

val dfMainOutputFinalWithoutNull = dfMainOutputFinal.withColumn("concatenated", regexp_replace(col("concatenated"), "\\|\\^\\|null", "")).withColumnRenamed("concatenated", header)


dfMainOutputFinalWithoutNull.write.partitionBy("DataPartition","StatementTypeCode")
  .format("csv")
  .option("nullValue", "")
  .option("delimiter", "\t")
  .option("quote", "\u0000")
  .option("header", "true")
  .option("codec", "gzip")
  .save("s3://trfsmallfffile/FinancialLineItem/output")

Right now I have to write the condition explicitly for every single column. Is there any way to avoid repeating the condition for all the columns?

In my data the "null" columns actually hold the literal string null rather than a real NULL, so applying coalesce directly might be difficult.
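(For reference, a minimal sketch of how such literal null strings could be normalised to real NULLs before coalescing; the helper name nullifyLiteralNull and the column in the usage comment are made up for illustration and are not part of the original code:)

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, lit, when}

    // Hypothetical helper: turn the literal string "null" into a real NULL
    // so that coalesce/isNotNull treat the value as missing.
    def nullifyLiteralNull(c: Column): Column =
      when(c === "null", lit(null).cast("string")).otherwise(c)

    // Example usage (column name assumed from the INCR data shown below):
    // val data1Clean = data1.withColumn("LocalLanguageLabel_1", nullifyLiteralNull(col("LocalLanguageLabel_1")))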

Here is DataFrame one:

LineItem.organizationId|^|LineItem.lineItemId|^|StatementTypeCode|^|LineItemName|^|LocalLanguageLabel|^|FinancialConceptLocal|^|FinancialConceptGlobal|^|IsDimensional|^|InstrumentId|^|LineItemSequence|^|PhysicalMeasureId|^|FinancialConceptCodeGlobalSecondary|^|IsRangeAllowed|^|IsSegmentedByOrigin|^|SegmentGroupDescription|^|SegmentChildDescription|^|SegmentChildLocalLanguageLabel|^|LocalLanguageLabel.languageId|^|LineItemName.languageId|^|SegmentChildDescription.languageId|^|SegmentChildLocalLanguageLabel.languageId|^|SegmentGroupDescription.languageId|^|SegmentMultipleFundbDescription|^|SegmentMultipleFundbDescription.languageId|^|IsCredit|^|FinancialConceptLocalId|^|FinancialConceptGlobalId|^|FinancialConceptCodeGlobalSecondaryId|^|FFAction|!|
4295879842|^|1246|^|CUS|^|Net Sales-Customer Segment|^|相手先別の販売高(相手先別)|^|JCSNTS|^|REXM|^|False|^||^||^||^||^|False|^|False|^|CUS_JCSNTS|^||^||^|505126|^|505074|^|505074|^|505126|^|505126|^||^|505074|^|True|^|3020155|^|3015249|^||^|I|!|

And here is my DataFrame two:

DataPartition_1|^|TimeStamp|^|LineItem.organizationId|^|LineItem.lineItemId|^|StatementTypeCode_1|^|LineItemName_1|^|LocalLanguageLabel_1|^|FinancialConceptLocal_1|^|FinancialConceptGlobal_1|^|IsDimensional_1|^|InstrumentId_1|^|LineItemSequence_1|^|PhysicalMeasureId_1|^|FinancialConceptCodeGlobalSecondary_1|^|IsRangeAllowed_1|^|IsSegmentedByOrigin_1|^|SegmentGroupDescription_1|^|SegmentChildDescription_1|^|SegmentChildLocalLanguageLabel_1|^|LocalLanguageLabel.languageId_1|^|LineItemName.languageId_1|^|SegmentChildDescription.languageId_1|^|SegmentChildLocalLanguageLabel.languageId_1|^|SegmentGroupDescription.languageId_1|^|SegmentMultipleFundbDescription_1|^|SegmentMultipleFundbDescription.languageId_1|^|IsCredit_1|^|FinancialConceptLocalId_1|^|FinancialConceptGlobalId_1|^|FinancialConceptCodeGlobalSecondaryId_1|^|FFAction_1
SelfSourcedPublic|^|1511869196612|^|4295902451|^|10|^|BAL|^|Short term notes payable - related party|^|null|^|null|^|LSOD|^|false|^|null|^|null|^|null|^|null|^|false|^|false|^|null|^|null|^|null|^|null|^|505074|^|null|^|null|^|null|^|null|^|null|^|null|^|null|^|3019157|^|null|^|I|!|

Here is what I have tried so far:

println("Enterin In to Spark Mode ")

    val conf = new SparkConf().setAppName("FinanicalLineItem").setMaster("local");
    val sc = new SparkContext(conf); //Creating spark context
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)


    val mainFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//MAIN"
    val incrFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//INCR"
    val outputFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//output"
    val descrFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//Descr"

    val src = new Path(outputFileURL)
    val dest = new Path(mainFileURL)
    val hadoopconf = sc.hadoopConfiguration
    val fs = src.getFileSystem(hadoopconf)

    sc.hadoopConfiguration.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")

    sc.hadoopConfiguration.set("parquet.enable.summary-metadata", "false")

    myUtil.Utility.DeleteOuptuFolder(fs, outputFileURL)
    myUtil.Utility.DeleteDescrFolder(fs, descrFileURL)

    import sqlContext.implicits._

    val rdd = sc.textFile(mainFileURL)
    val header = rdd.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
    val schema = StructType(header.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
    val data = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema)

    val schemaHeader = StructType(header.map(cols => StructField(cols.replace(".", "."), StringType)).toSeq)
    val dataHeader = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schemaHeader)

    val get_cus_val = sqlContext.udf.register("get_cus_val", (filePath: String) => filePath.split("\\.")(3))

    val columnsNameArray = schema.fieldNames

    val df1resultFinal = data.withColumn("DataPartition", get_cus_val(input_file_name))
    val rdd1 = sc.textFile(incrFileURL)
    val header1 = rdd1.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
    val schema1 = StructType(header1.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
    val data1 = sqlContext.createDataFrame(rdd1.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema1)

    val windowSpec = Window.partitionBy("LineItem_organizationId", "LineItem_lineItemId").orderBy($"TimeStamp".cast(LongType).desc)
    val latestForEachKey = data1.withColumn("rank", rank().over(windowSpec)).filter($"rank" === 1).drop("rank", "TimeStamp")

    val columnMap = latestForEachKey.columns
      .filter(_.endsWith("_1"))
      .map(c => c -> c.dropRight(2))
      .toMap + ("FFAction_1" -> "FFAction|!|")


        val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))
        val exprsExtended = exprs ++ Array(col("LineItem_organizationId"), col("LineItem_lineItemId"))
        println(exprsExtended)
        val df2 = data.select(exprsExtended: _*) // This line has a compilation issue.

type mismatch;
 found   : scala.collection.immutable.Iterable[org.apache.spark.sql.Column]
 required: Seq[?]
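(This happens because columnMap here is a Map, so .map produces an Iterable[Column] rather than a Seq. A minimal sketch of one way to satisfy select's varargs signature, assuming the same columnMap as above, is to convert the result to a Seq first:)

    // Converting the Iterable produced by Map.map into a Seq fixes the type mismatch.
    val exprs = columnMap.map(t => coalesce(col(t._1), col(t._2)).as(t._2)).toSeq
    val exprsExtended = exprs ++ Seq(col("LineItem_organizationId"), col("LineItem_lineItemId"))
    val df2 = data.select(exprsExtended: _*)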

Also, when I print exprsExtended, the column aliases show up wrapped in backticks, for example:

coalesce(LineItemSequence_1, LineItemSequence) AS `LineItemSequence`,
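(For what it's worth, the backticks come from how a Column is rendered when it is printed as a string; the actual DataFrame column names produced by the select contain no backticks. A tiny illustration, assuming the usual functions import:)

    import org.apache.spark.sql.functions.{coalesce, col}

    val expr = coalesce(col("LineItemSequence_1"), col("LineItemSequence")).as("LineItemSequence")
    println(expr) // prints: coalesce(LineItemSequence_1, LineItemSequence) AS `LineItemSequence`
    // Selecting this expression yields a column literally named LineItemSequence.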

1 answer:

Answer 0 (score: 3):

The first step is to build a list of tuples containing the column names to use in all the when clauses. This can be done in several ways, but if all the columns of the DataFrame should be used, it can be done as follows (using an example DataFrame):

val df = Seq(("1", "2", null, "4", "5", "6"), 
    (null, "2", "3", "4", null, "6"), 
    ("1", "2", "3", "4", null, "6"))
  .toDF("col1_1", "col1", "col2_1", "col2", "col3_1", "col3|!|")

val columnMap = df.columns.grouped(2).map(a => (a(0), a(1))).toArray

The columnMap variable now contains the column pairs to use, as tuples:

("col1_1", "col1")
("col1_2", "col2")
("col1_3", "col3|!|")

The next step is to build, from the columnMap variable, the expressions that can be used in the select statement:

val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))

and then apply the expressions to the DataFrame:

val df2 = df.select(exprs:_*)

The final result looks like this:

+----+----+-------+
|col1|col2|col3|!||
+----+----+-------+
|   1|   4|      5|
|   2|   3|      6|
|   1|   3|      6|
+----+----+-------+

Note: if other columns should be selected in addition to those produced from the exprs variable, simply append them like this:

val exprsExtended = exprs ++ Array(col("other_column1), col("other_column2"))
val df2 = df.select(exprsExtended :_*)

Edit: to create the columnMap in this specific case, with column names like these, it looks easiest to start from all the columns that carry the _1 suffix. Take those columns from the latestForEachKey DataFrame before the join:

val columnMap = latestForEachKey.columns
  .filter(c => c.endsWith("_1") && c != "FFAction_1")
  .map(c => c -> c.dropRight(2)) :+ ("FFAction_1" -> "FFAction|!|")
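With the column names from the question this produces pairs such as (an illustrative excerpt, not the full list):

("DataPartition_1", "DataPartition")
("StatementTypeCode_1", "StatementTypeCode")
("LineItemName_1", "LineItemName")
...
("FFAction_1", "FFAction|!|")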

Then create and use exprs and exprsExtended as described above:

val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))
val exprsExtended = exprs ++ Array(col("LineItem_organizationId"), col("LineItem_lineItemId"))
val df2 = df.select(exprsExtended:_*)
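Putting this together for the question's DataFrames, the long when/otherwise chain could be replaced by something along these lines (a sketch only, reusing df1resultFinal and latestForEachKey from the question; if the source data stores the string "null" instead of real NULLs, normalise it first as noted above, since coalesce only skips real NULLs):

val joined = df1resultFinal.join(latestForEachKey,
  Seq("LineItem_organizationId", "LineItem_lineItemId"), "outer")

// Pair each *_1 column with its base column, special-casing FFAction_1.
val columnMap = latestForEachKey.columns
  .filter(c => c.endsWith("_1") && c != "FFAction_1")
  .map(c => c -> c.dropRight(2)) :+ ("FFAction_1" -> "FFAction|!|")

// coalesce picks the *_1 value when it is not NULL, otherwise the base value.
val exprs = columnMap.map { case (incr, base) => coalesce(col(incr), col(base)).as(base) }
val exprsExtended = exprs ++ Array(col("LineItem_organizationId"), col("LineItem_lineItemId"))

val dfMainOutput = joined.select(exprsExtended: _*).filter(!$"FFAction|!|".contains("D|!|"))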