Scala中Spark数据帧的有序拆分

时间:2020-06-05 08:21:56

标签: scala apache-spark apache-spark-sql

我在spark中有一个sql.DataFrame,我想用scala将其拆分为训练和测试数据帧。 我现在正在使用此代码:

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

但是我不想随机分割数据框。我想要的是将其中的70%作为培训,其余作为测试。最好的方法是什么?

2 个答案:

答案 0 :(得分:2)

如果您的id列是单调的,并且您知道最小值/最大值(或可以使用df.agg(min(col("id")), max(col("id")).collect()之类的值来找到它们),那么您可以简单地找到您的“分割值”,例如{ {1}}(如果我没记错的话)。然后,您可以使用过滤器将df拆分为:

(maxValue - minValue) * 0.7 + minValue

答案 1 :(得分:2)

Pi带@ rayan-ral的答案-或者,如果您希望Spark自己计算id列的最小/最大值,则可以repartition按范围{strong>} 放入N分区中,然后将0.7 * N partitions 作为您的训练集。

df = df.repartitionByRange(100,col("id"))
val train = df.filter(spark_partition_id() < 70)