对于循环Spark数据帧

时间:2016-07-21 10:35:44

标签: scala apache-spark

我有一个Dataframe df,其中包含一列groupID;也就是说,每个观察都属于一个特定的群体。总共有8组。我想从每个groupID中抽取一定百分比的观察结果(比方说,20%)。以下是我这样做的方法:

val sample_df = for ( i <- Array.range(0,7) ) yield {  
             val sel_df = df.filter($"groupID"===i)  
             sel_df.sample(false,0.2,seed1)  
             }  

此代码的结果是:

Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])

我在flatMap()上应用了sample_df,但我收到了错误消息:

val flat_df = sample_df.flatMap(x => x)
         <console>:59: error: type mismatch;
         found: org.apache.spark.sql.DataFrame
         required: scala.collection.GenTraversableOnce[?]

如何获取采样数据帧?

3 个答案:

答案 0 :(得分:2)

据我了解,您正试图获得RDD Row。为此,您只需致电:

val rows: RDD[Row] = sample_df.rdd

为了解释你变得更好的错误,flatMap需要像Option这样的可遍历的东西,但你只提供了一个Row

此外,要将所有数据提供给驱动程序,您可以调用:

val rows: Array[Row] = sample_df.collect

答案 1 :(得分:1)

我猜你想在每个小组上均匀采样。

sample_df.reduceLeft((result, df) => result.unionAll(df))

答案 2 :(得分:0)

在我看来,您只想采用整个数据帧的20%样本?如果是这样,那么就没有理由创建8个不同的数据帧然后将它们联合起来。

df.sample(false, 0.2, seed)

会做到这一点。如果您想为每个groupID执行不同的分数,请查看df.stat.sampleBy。如果您想确保样本中每个类的完全 20%,那么您必须转换为PairRDD并使用分层抽样,如:

df.rdd.map(row => (row(groupIDIndex), row)).sampleByKeyExact(false, Map(0 -> 0.2, 1 -> 0.2, ..., 8 -> 0.2), seed)