在具有约束的Apache Spark(Scala)数据框中将布尔列转换为数值列?

时间:2017-10-31 18:28:52

标签: scala spark-dataframe

 val inputfile = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("header", "true") 
        .option("inferSchema", "true") 
        .option("delimiter", "\t")
        .load("data")
 inputfile: org.apache.spark.sql.DataFrame = [a: string, b: bigint, c: boolean]
 val outputfile = inputfile.groupBy($"a",$"b").max($"c")

上面的代码失败,因为c是一个布尔变量,聚合不能应用于布尔值。 Spark中是否有一个函数可以将true值转换为1并将false转换为0以获取Spark数据框的完整列。

我尝试了以下内容(来源:How to change column types in Spark SQL's DataFrame?

 val inputfile = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("header", "true") 
        .option("inferSchema", "true") 
        .option("delimiter", "\t")
        .load("data")
 val tempfile =inputfile.select("a","b","c").withColumn("c",toInt(inputfile("c")))   
 val outputfile = tempfile.groupBy($"a",$"b").max($"c")

以下问题:PySpark的Casting a new derived column in a DataFrame from boolean to integer答案,但我想要一个专门针对Scala的函数。

感谢任何帮助。

3 个答案:

答案 0 :(得分:2)

您无需使用udf即可执行此操作。如果要将布尔值转换为int,可以将列转换为int

val df2 = df1
  .withColumn("boolAsInt",$"bool".cast("Int")

答案 1 :(得分:1)

implicit def bool2int(b:Boolean) = if (b) 1 else 0

scala> false:Int
res4: Int = 0

scala> true:Int
res5: Int = 1

scala> val b=true
b: Boolean = true


scala> 2*b+1
res2: Int = 3

使用上述功能并注册为UDF

val bool2int_udf = udf(bool2int _)

val tempfile =inputfile.select("a","b","c").withColumn("c",bool2int_UDF($("c")))

答案 2 :(得分:1)

下面的代码对我有用。 @ Achyuth的回答提供了部分功能。然后,从这个问题中提出想法:Applying function to Spark Dataframe Column 我能够使用UDF将来自Achyuth的函数应用于数据框的完整列。这是完整的代码。

 implicit def bool2int(b:Boolean) = if (b) 1 else 0
 val bool2int_udf = udf(bool2int _)
 val inputfile = sqlContext.read
        .format("com.databricks.spark.csv")
        .option("header", "true") 
        .option("inferSchema", "true") 
        .option("delimiter", "\t")
        .load("data") 
 val tempfile = inputfile.select("a","b","c").withColumn("c",bool2int_udf($"c"))
 val outputfile = tempfile.groupBy($"a",$"b").max($"c")