Recursively renaming columns in a nested structure in Spark

Date: 2018-07-13 17:10:45

Tags: scala apache-spark apache-spark-sql

I am trying to replace certain characters in all the columns of a DataFrame that contains some nested Struct types.

I tried to process the schema fields recursively, but for some reason it only renames the fields at the top level, even though it does reach the leaf nodes.

I am trying to replace the ':' character in the column names with '_'.

Here is the Scala code I wrote.

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

class UpdateSchema {

  val logger = LoggerFactory.getLogger(classOf[UpdateSchema])

  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val sparkSession = SparkLauncher.spark // our own SparkSession holder

  import sparkSession.implicits._

  def updateSchema(filePath: String): Boolean = {
    logger.info(".updateSchema() : filePath ={}", filePath)
    logger.info(".updateSchema() : sparkSession ={}", sparkSession)
    if (sparkSession != null) {
      val xmlDF = sparkSession
        .read
        .format("com.databricks.spark.xml")
        .option("rowTag", "ns:fltdMessage")
        .option("inferSchema", "true")
        .option("attributePrefix", "attr_")
        .load(filePath)
        .toDF()

      xmlDF.printSchema()
      val updatedDF = renameDataFrameColumns(xmlDF.toDF())
      updatedDF.printSchema()
    }
    else
      logger.info(".updateSchema(): Spark Session is NULL !!!")
    false
  }

  def replaceSpecialChars(str: String): String =
    str.replaceAll(":", "_")

  def renameColumn(df: DataFrame, colName: String, prefix: String): DataFrame = {
    val newColName: String = replaceSpecialChars(colName)
    logger.info(".renameColumn(): prefix=[" + prefix + "] colName=[" + colName + "] New Column Name=[" + newColName + "]")
    if (prefix.equals("")) {
      if (df.col(colName) != null) {
        df.withColumnRenamed(colName, newColName)
      }
      else {
        logger.error(".renameColumn() : Column [" + prefix + "." + colName + "] Not found in DataFrame !! ")
        logger.info("Prefix =" + prefix + " Existing Columns =[" + df.columns.mkString("),(") + "]")
        throw new Exception("Unable to find Column [" + prefix + "." + colName + "]")
      }
    }
    else {
      if (df.col(prefix + "." + colName) != null) {
        // withColumnRenamed only matches top-level column names, so this
        // silently does nothing for a nested field like "a.b"
        df.withColumnRenamed(prefix + "." + colName, prefix + "." + newColName)
      }
      else {
        logger.error(".renameColumn() : Column [" + prefix + "." + colName + "] Not found in DataFrame !! ")
        logger.info("Prefix =" + prefix + " Existing Columns =[" + df.columns.mkString("),(") + "]")
        throw new Exception("Unable to find Column [" + prefix + "." + colName + "]")
      }
    }
  }

  def getStructType(schema: StructType, fieldName: String): StructType = {
    schema.fields.foreach(field => {
      field.dataType match {
        case st: StructType =>
          logger.info(".getStructType(): Current Field Name =[" + field.name + "] Checking for =[" + fieldName + "]")
          if (field.name.equals(fieldName)) {
            return field.dataType.asInstanceOf[StructType]
          }
          else {
            getStructType(st, fieldName)
          }
        case _ =>
          logger.info(".getStructType(): Non Struct Type. Ignoring Field=[" + field.name + "]")
      }
    })
    throw new Exception("Unable to find Struct Type for field Name[" + fieldName + "]")
  }

  def processSchema(df: DataFrame, schema: StructType, prefix: String): DataFrame = {
    var updatedDF: DataFrame = df
    schema.fields.foreach(field => {
      field.dataType match {
        case st: StructType =>
          logger.info(".processSchema() : Struct Type =[" + st + "]")
          logger.info(".processSchema() : Field Data Type =[" + field.dataType + "]")
          logger.info(".processSchema() : Renaming the Struct Field =[" + field.name + "] st=[" + st.fieldNames.mkString(",") + "]")
          updatedDF = renameColumn(updatedDF, field.name, prefix)
          logger.info(".processSchema() : Column List after Rename =[" + updatedDF.columns.mkString(",") + "]")
          val renamedCol: String = replaceSpecialChars(field.name)
          if (prefix.trim().equals("")) {
            updatedDF = processSchema(updatedDF,
              getStructType(updatedDF.schema, renamedCol),
              renamedCol)
          }
          else {
            updatedDF = processSchema(updatedDF,
              getStructType(updatedDF.schema, renamedCol),
              prefix + "." + renamedCol)
          }
        case _ =>
          updatedDF = renameColumn(updatedDF, field.name, prefix)
      }
    })
    updatedDF
  }

  def renameDataFrameColumns(df: DataFrame): DataFrame = {
    val schema = df.schema
    processSchema(df, schema, "")
  }
}

2 Answers:

Answer 0 (score: 1)

Unfortunately, you can't easily rename a single nested field with withColumnRenamed the way you are trying to. The only way I know of to rename a nested field is to cast the field to a type with the same structure and data types, but new field names. This has to be done on the top-level field, so you need to do all the fields at once. Here's an example:

Create some input data:

case class InnerRecord(column1: String, column2: Int)
case class Record(field: InnerRecord)

val df = Seq(
    Record(InnerRecord("a", 1)),
    Record(InnerRecord("b", 2))
).toDF

df.printSchema

The input data looks like this:

root
 |-- field: struct (nullable = true)
 |    |-- column1: string (nullable = true)
 |    |-- column2: integer (nullable = false)

Here's an example using withColumnRenamed. You'll notice from the output that it actually doesn't do anything!

val updated = df.withColumnRenamed("field.column1", "field.newname")
updated.printSchema

root
 |-- field: struct (nullable = true)
 |    |-- column1: string (nullable = true)
 |    |-- column2: integer (nullable = false)

And here is how you can do it with a cast. The function recursively recreates the nested field types, updating the names as it goes. In my case, I just replaced "column" with "col_". I also only ran it on the one field, but you could easily loop over all the fields in your schema (see the sketch after the output below).

import org.apache.spark.sql.types._

def rename(dataType: DataType): DataType = dataType match {
    case StructType(fields) => 
        StructType(fields.map { 
            case StructField(name, dtype, nullable, meta) => 
                val newName = name.replace("column", "col_")
                StructField(newName, rename(dtype), nullable, meta)
        })

    case _ => dataType
}


val fDataType = df.schema.filter(_.name == "field").head.dataType
val updated = df.withColumn("field", $"field".cast(rename(fDataType)))
updated.printSchema

Which prints:

root
 |-- field: struct (nullable = true)
 |    |-- col_1: string (nullable = true)
 |    |-- col_2: integer (nullable = false)
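
As mentioned above, the same cast can be applied across every top-level field. A minimal sketch of that loop, assuming the rename function defined above is in scope (the renameAllFields name is illustrative):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Fold over all top-level fields, casting each one to its renamed type.
// The cast is a no-op for non-struct fields, since rename returns them unchanged.
// Note: top-level names containing '.' or ':' would need backtick-quoting here.
def renameAllFields(df: DataFrame): DataFrame =
  df.schema.fields.foldLeft(df) { (acc, field) =>
    acc.withColumn(field.name, col(field.name).cast(rename(field.dataType)))
  }

renameAllFields(df).printSchema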

Answer 1 (score: 1)

Here's a recursive method that revises the DataFrame schema by renaming, via replaceAll, any column whose name contains the substring to be replaced:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

def renameAllColumns(schema: StructType, from: String, to:String): StructType = {

  def recurRename(schema: StructType, from: String, to:String): Seq[StructField] =
    schema.fields.map{
      case StructField(name, dtype: StructType, nullable, meta) =>
        StructField(name.replaceAll(from, to), StructType(recurRename(dtype, from, to)), nullable, meta)
      case StructField(name, dtype, nullable, meta) =>
        StructField(name.replaceAll(from, to), dtype, nullable, meta)
    }

  StructType(recurRename(schema, from, to))
}

Testing the method on a sample DataFrame with a nested structure:

case class M(i: Int, `p:q`: String)
case class N(j: Int, m: M)

val df = Seq(
  (1, "a", N(7, M(11, "x"))),
  (2, "b", N(8, M(21, "y"))),
  (3, "c", N(9, M(31, "z")))
).toDF("c1", "c2:0", "c3")

df.printSchema
// root
//  |-- c1: integer (nullable = false)
//  |-- c2:0: string (nullable = true)
//  |-- c3: struct (nullable = true)
//  |    |-- j: integer (nullable = false)
//  |    |-- m: struct (nullable = true)
//  |    |    |-- i: integer (nullable = false)
//  |    |    |-- p:q: string (nullable = true)

val rdd = df.rdd

val newSchema = renameAllColumns(df.schema, ":", "_")

spark.createDataFrame(rdd, newSchema).printSchema
// root
//  |-- c1: integer (nullable = false)
//  |-- c2_0: string (nullable = true)
//  |-- c3: struct (nullable = true)
//  |    |-- j: integer (nullable = false)
//  |    |-- m: struct (nullable = true)
//  |    |    |-- i: integer (nullable = false)
//  |    |    |-- p_q: string (nullable = true)

Note that since replaceAll understands regex patterns, the method can also be applied to trim off everything in a column name starting at the ':' character, for example:

val newSchema = renameAllColumns(df.schema, """:.*""", "")

spark.createDataFrame(rdd, newSchema).printSchema
// root
//  |-- c1: integer (nullable = false)
//  |-- c2: string (nullable = true)
//  |-- c3: struct (nullable = true)
//  |    |-- j: integer (nullable = false)
//  |    |-- m: struct (nullable = true)
//  |    |    |-- i: integer (nullable = false)
//  |    |    |-- p: string (nullable = true)
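
As a side note, spark.createDataFrame(rdd, newSchema) round-trips through the RDD API. Here is a minimal sketch (not part of the original answer) that stays in the DataFrame API instead, assuming the renameAllColumns method above: alias handles the top-level renames, while the cast picks up the nested ones (renamedDF is an illustrative name):

import org.apache.spark.sql.functions.col

val renamedSchema = renameAllColumns(df.schema, ":", "_")

// Pair each old field with its renamed counterpart, then select them all.
val renamedDF = df.select(
  df.schema.fields.zip(renamedSchema.fields).map { case (oldField, newField) =>
    // backticks keep a ':' in the old name from being parsed as a field path
    col(s"`${oldField.name}`").cast(newField.dataType).as(newField.name)
  }: _*
)

renamedDF.printSchema  // prints the same renamed schema as above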