我有一个Dataframe df
,其中包含一列groupID
;也就是说,每个观察都属于一个特定的群体。总共有8组。我想从每个groupID
中抽取一定百分比的观察结果(比方说,20%)。以下是我这样做的方法:
val sample_df = for ( i <- Array.range(0,7) ) yield {
val sel_df = df.filter($"groupID"===i)
sel_df.sample(false,0.2,seed1)
}
此代码的结果是:
Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])
我在flatMap()
上应用了sample_df
,但我收到了错误消息:
val flat_df = sample_df.flatMap(x => x)
<console>:59: error: type mismatch;
found: org.apache.spark.sql.DataFrame
required: scala.collection.GenTraversableOnce[?]
如何获取采样数据帧?
答案 0 :(得分:2)
val rows: RDD[Row] = sample_df.rdd
为了解释你变得更好的错误,flatMap需要像Option
这样的可遍历的东西,但你只提供了一个Row
。
此外,要将所有数据提供给驱动程序,您可以调用:
val rows: Array[Row] = sample_df.collect
答案 1 :(得分:1)
我猜你想在每个小组上均匀采样。
sample_df.reduceLeft((result, df) => result.unionAll(df))
答案 2 :(得分:0)
在我看来,您只想采用整个数据帧的20%样本?如果是这样,那么就没有理由创建8个不同的数据帧然后将它们联合起来。
df.sample(false, 0.2, seed)
会做到这一点。如果您想为每个groupID执行不同的分数,请查看df.stat.sampleBy
。如果您想确保样本中每个类的完全 20%,那么您必须转换为PairRDD并使用分层抽样,如:
df.rdd.map(row => (row(groupIDIndex), row)).sampleByKeyExact(false, Map(0 -> 0.2, 1 -> 0.2, ..., 8 -> 0.2), seed)