从训练集数据

时间:2016-11-15 04:51:15

标签: scala apache-spark apache-spark-sql apache-spark-mllib

我的数据分为两个colorsexcluded_colors

colors包含所有颜色 excluded_colors包含一些我希望从训练集中排除的颜色。

我正在尝试将数据拆分为训练和测试集,并确保excluded_colors中的颜色不在我的训练集中但存在于测试集中。

为了达到上述目的,我做了这个

var colors = spark.sql("""
   select colors.* 
   from colors 
   LEFT JOIN excluded_colors 
   ON excluded_colors.color_id = colors.color_id
   where excluded_colors.color_id IS NULL
"""
)
val trainer: (Int => Int) = (arg:Int) => 0
val sqlTrainer = udf(trainer)
val tester: (Int => Int) = (arg:Int) => 1
val sqlTester = udf(tester)

val rsplit = colors.randomSplit(Array(0.7, 0.3))  
val train_colors = splits(0).select("color_id").withColumn("test",sqlTrainer(col("color_id")))
val test_colors = splits(1).select("color_id").withColumn("test",sqlTester(col("color_id")))

然而,我意识到通过上面这些excluded_colors中的颜色被完全忽略了。它们甚至不在我的测试集中。

问题 如何在70/30中拆分数据,同时确保excluded_colors中的颜色不在训练中但在测试中存在。

1 个答案:

答案 0 :(得分:1)

我们想要做的是删除"排除的颜色"从训练集开始,但让他们参加测试并进行70/30的训练/测试分组。

我们需要的是一些数学。

考虑到总数据集(TD)和排除颜色数据集(E),我们可以说对于火车数据集(Tr)和测试数据集(Ts):

|Tr| = x * (|TD|-|E|)
|Ts| = |E| + (1-x) * |TD|

我们也知道|Tr| = 0.7 |TD|

因此x = 0.7 |TD| / (|TD| - |E|)

现在我们知道了采样因子x,我们可以说:

Tr = (TD-E).sample(withReplacement = false, fraction = x)
// where (TD - E) is the result of the SQL expr above

Ts = TD.sample(withReplacement = false, fraction = 0.3)
// we sample the test set from the original dataset