将Scala Seq数据集拆分为训练和测试集

时间:2013-12-09 17:49:27

标签: scala

在Scala中,我有一个表示为Seq [T]的数据集。 我想将它分成训练数据集和测试数据集。 这可以基于简单的随机抽样。 我知道如何使用一对ListBuffers等来做到这一点。但这个公式是否也正确?

def splitIntoTrainingAndTest[T](all: Seq[T], samplingRate: Double): (Seq[T], Seq[T]) = {
  val r1 = new Random(123)
  val r2 = new Random(123)
  (
    all.filter({ i: T => r1.nextDouble() < samplingRate }),
    all.filter({ i: T => r2.nextDouble() >= samplingRate })
  )
}

请注意,我对java.util.Random个实例使用相同的随机种子。我只是想知道将来.filter(…)是否会被懒惰地实施......

2 个答案:

答案 0 :(得分:4)

您可能最好使用Seqpartition方法:

// Partitions this sequence in two sequences according to a predicate.
def partition(p: (A) ⇒ Boolean): (Seq[A], Seq[A])

然后你只需要一个RNG即可all.partition(_ => r.nextDouble() < samplingRate)

答案 1 :(得分:1)

对于延迟评估,您可以将其转换为allStream,然后使用partition根据给定的谓词将all拆分为流。在你的情况下,这将是:

all.toStream.partition(_ < samplingRate)

返回TupleStream s。

(scala.collection.immutable.Stream[Int], scala.collection.immutable.Stream[Int])