如何在2个不同的DataFrame中添加相应的Integer值

时间:2017-03-09 12:52:57

标签: scala apache-spark

我的代码中有两个DataFrames具有完全相同的维度,比方说1,000,000 X 50.我需要在两个数据帧中添加相应的值。如何实现这一点。

一种选择是添加另一个包含ID的列,union两个DataFrame,然后使用reduceByKey。但还有其他更优雅的方式吗?

感谢。

1 个答案:

答案 0 :(得分:1)

你的方法很好。另一种选择可以是两个选择RDD并将它们压缩在一起,然后迭代这些选项以对列进行求和并使用任何原始数据帧模式创建新的数据帧。 假设所有列的数据类型都是整数,则此代码片段应该有效。请注意,这已在spark 2.1.0中完成。

    import spark.implicits._

    val a: DataFrame = spark.sparkContext.parallelize(Seq(
      (1, 2),
      (3, 6)
    )).toDF("column_1", "column_2")

    val b: DataFrame = spark.sparkContext.parallelize(Seq(
      (3, 4),
      (1, 5)
    )).toDF("column_1", "column_2")

    // Merge rows
    val rows = a.rdd.zip(b.rdd).map{
      case (rowLeft, rowRight) => {
        val totalColumns = rowLeft.schema.fields.size
        val summedRow = for(i <- (0 until totalColumns)) yield rowLeft.getInt(i) + rowRight.getInt(i)
        Row.fromSeq(summedRow)
      }
    }

    // Create new data frame
    val ab: DataFrame = spark.createDataFrame(rows, a.schema) // use any of the schemas
    ab.show()

更新: 所以,我试着尝试我的解决方案与你的解决方案的性能。我测试了100000行,每行有50列。如果你的方法有51列,额外的一列是ID列。在一台机器(没有集群)中,我的解决方案似乎运行得更快。

  1. 联合和分组方法大约需要5598毫秒。
  2. 我的解决方案大约需要5378毫秒。 我的假设是第一个解决方案需要花费更多时间,因为两个数据帧的并集操作。
  3. 以下是我为测试方法而创建的方法。

      def option_1()(implicit spark: SparkSession): Unit = {
        import spark.implicits._
        val a: DataFrame = getDummyData(withId = true)
        val b: DataFrame = getDummyData(withId = true)
        val allData = a.union(b)
    
        val result = allData.groupBy($"id").agg(allData.columns.collect({ case col if col != "id" => (col, "sum") }).toMap)
        println(result.count())
        //    result.show()
      }
    
    
      def option_2()(implicit spark: SparkSession): Unit = {
        val a: DataFrame = getDummyData()
        val b: DataFrame = getDummyData()
    
        // Merge rows
        val rows = a.rdd.zip(b.rdd).map {
          case (rowLeft, rowRight) => {
            val totalColumns = rowLeft.schema.fields.size
            val summedRow = for (i <- (0 until totalColumns)) yield rowLeft.getInt(i) + rowRight.getInt(i)
            Row.fromSeq(summedRow)
          }
        }
    
        // Create new data frame
        val result: DataFrame = spark.createDataFrame(rows, a.schema) // use any of the schemas
        println(result.count())
        //    result.show()
      }