Spark / Scala 1.6如何使用数据帧groupby agg来实现以下逻辑?

时间:2017-04-11 11:38:12

标签: scala apache-spark apache-spark-sql spark-dataframe

如何使用数据帧groupby agg来实现以下逻辑?

ID    ID2   C1   C2   C3   C4   C5   C6 .....C33
CM1    a    1    1    1    0    0    0
CM2    a    1    1    0    1    0    0
CM3    a    1    0    1    1    1    0
CM4    a    1    1    1    1    1    0
CM5    a    1    1    1    1    1    0
1k2    b    0    0    1    1    1    0
1K3    b    1    1    1    1    1    0
1K1    b    0    0    0    0    1    0

我希望我的输出df看起来像这样

ID    ID2   C1   C2   C3   C4   C5   C6 .....C33
CM1    a    1    1    1    0    0    0
CM2    a    0    0    0    1    0    0
CM3    a    0    0    0    0    1    0
CM4    a    0    0    0    0    0    0
CM5    a    0    0    0    0    0    0
1K1    b    0    0    0    0    1    0
1k2    b    0    0    1    1    0    0
1K3    b    1    1    0    0    0    0

逻辑基于ID2 do group by,然后在Cn为1时找到最小ID,然后设置为1,其他则设置为0.

Cn达到C33。

如果用例类超出限制。

我尝试过使用mapPartitions

但结果是错误的......

使用Spark 1.6.0

添加我尝试过的代码

case class testGoods(ID: String, ID2: String, C1 : String, C2 : String)

val cartMap = new HashMap[String, Set[(String,String,String)]] with MultiMap[String,(String,String,String)]

val baseDF=hiveContext.sql(newSql)

val testRDD=baseDF.mapPartitions( partition => {
  while (partition.hasNext) {
    val record = partition.next()
    val ID = record.getString(0)
    if (ID != null && ID != "null") {
      val ID2=record.getString(1)
      val C1=record.getString(2)
      val C2=record.getString(3)
      cartMap.addBinding(ID2, (ID,C1,C2))
    }
  }
  cartMap.iterator
})

val recordList = new mutable.ListBuffer[testGoods]()
val testRDD1=testRDD.mapPartitions( partition => {
  while (partition.hasNext) {
    val record = partition.next()
    val ID2=record._1
    val recordRow= record._2
    val sortedRecordRow = TreeSet[(String,String,String)]() ++ recordRow
    val dic=new mutable.HashMap[String,String]


    for(v<-sortedRecordRow) {
      val ID = v._1
      val C1 = v._2
      val C2 = v._3

      if (dic.contains(ID2)){
        val goodsValue=dic.get(ID2)
        if("1".equals(goodsValue)){
          recordList.append(new testGoods(ID, ID2, "0", C2))
        }else{
          dic.put(ID2,C1)
          recordList.append(new testGoods(ID, ID2, C1,C2))
        }
      }else{
        dic.put(ID2,C1)
        recordList.append(new testGoods(ID, ID2, C1, C2))
      }
    }
  }
  recordList.iterator
})

再次编辑

原始数据集有数百万个ID,按ID分组后,每个ID2可能有2~300个数据。

2 个答案:

答案 0 :(得分:0)

这是我如何解决这个问题的算法。绝对有足够的改进空间。 基本上我建议解决它的两个阶段。

  1. 为数据集中的每个1构建一个ID地图。

  2. 根据此地图绘制完整数据框 - 用0替换所有不满意的记录

  3. val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")

    val Limit = 5 // or 33
    
    val re = (1 to Limit).map { i =>
        val col = "C" + i
        val first = df.filter(df(col) > 0).head
        val id  = first(0) // id column name
        i -> id
    }.toMap
    // result of this stage is smth like:Map(5 -> CM3, 1 -> CM1, 2 -> CM1, 3 -> CM1, 4 -> CM2)    
    
    df.map(columns => 
         Seq(columns(0), columns(1)) ++ 
            (1 to Limit).map { x => 
                 if (columns(0) == re(x)) 1 else 0 
          }).foreach(println)
    

答案 1 :(得分:0)

基本上如果您的子表很小(正如您在评论中提到的2~300个数据点),您只需要这样:

val columnIds = List(2, 3, 4, 5, 6, 7)// (preColumns to preColumns + numColumns).toList
val columnNames = List("C1", "C2", "C3", "C4", "C5", "C6") //just for representability

val withKey = df.rdd.map(c => c.getString(1) -> c).groupByKey

val res = withKey.flatMap{ 
  case (id, local) =>
    case class Accumulator(found: Set[Int] = Set.empty, result: List[Row] = List.empty)    
    local.foldLeft(Accumulator()){
      case (acc, row) => 
        val found = columnIds.filter(id => row.getInt(id) != 0) //columns with `1`
        val pre = Seq(row(0), row(1))
        val res = pre ++ columnIds.map{ cid => 
          if (acc.found.contains(cid)) 0 else row.get(cid)
        }
        Accumulator(acc.found ++ found, Row.fromSeq(res) :: acc.result)
    }.result.reverse
}
  • 您可以使用比reverse更合适的集合来删除List累加器(例如Queue
  • Accumulator这里介绍的是为了避免在表示我们已经找到的列时的可变变量&#34; 1&#34;。但是,您可以在Spark中使用var内部lambda(但不在外部!) - 它也可以使用。
  • Accumulator将已处理的列保存在found中,结果本身保存在result
  • groupByKey之所以被使用是因为您无法在Spark中访问rdd中的rdd,您只能在流内部执行此操作。它也是一个更智能的&#34;替换代码中的mapPartitions - 它可以理解&#34;按键分组&#34;而不是分区。
  • local是您的数据框的本地子集,您基本上可以将其作为常规Scala - 集合进行操作。

RDD转换回DataFrame

val newDf = sqlContext.createDataFrame(res, df.schema)

实验:

data.csv:
ID,ID2,C1,C2,C3,C4,C5,C6
CM1,a,1,1,1,0,0,0
CM2,a,1,1,0,1,0,0
CM3,a,1,0,1,1,1,0
1k2,b,0,0,1,1,1,0
1K3,b,1,1,1,1,1,0
1K1,b,0,0,0,0,1,0

val sqlContext = new SQLContext(sc)

val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("data.csv")
...
res.collect()

res77_4: Array[Row] = Array(
[CM1,a,1,1,1,0,0,0],
  [CM2,a,0,0,0,1,0,0],
  [CM3,a,0,0,0,0,1,0],
  [1k2,b,0,0,1,1,1,0],
  [1K3,b,1,1,0,0,0,0],
  [1K1,b,0,0,0,0,0,0]
)

对于处理稀疏数据(有很多零)的其他情况 - 考虑使用Mllib的矩阵:https://spark.apache.org/docs/2.1.0/mllib-data-types.html#distributed-matrix。他们可以更有效地保存和处理稀疏结构。

我还建议使用案例类重新定位Row以避免使用列索引。