XGBoost在标签列上使用开窗功能后失败

时间:2018-07-11 20:47:08

标签: scala apache-spark xgboost windowing

我已经成功训练了一个XGBoost模型,其中trainDF是一个包含两列的数据帧:featureslabel,其中我们有11k 1和57M 0(不平衡数据集)。一切正常。

val udnersample = 0.1
// Undersampling of 0's -- choosing 10%
val training1 = output1.filter($"datestr" < end_period1 && 
    $"label" === 1)
val training0 = output1.filter($"datestr" < end_period1 && 
    $"label" === 0).sample(
    false, undersample)
val training = training0.unionAll(training1)
val traindDF = training.select("label", 
    "features").toDF("label", "features")}
val paramMap = List("eta" -> 0.05,
                    "max_depth" -> 6,
                    "objective" -> "binary:logistic").toMap
val num_trees = 400
val num_cores = 200
val XGBModel = XGBoost.trainWithDataFrame(trainDF, 
                                          paramMap, 
                                          num_trees, 
                                          num_cores, 
                                          useExternalMemory = true)

然后,我想通过一些窗口更改y标签,以便在每个组中可以更早地预测y标签。

val sum_label = "sum_label"
val label_window_length = 19
val sliding_window_label =  Window.partitionBy("id").orderBy(
    asc("timestamp")).rowsBetween(0, label_window_length)

val training_source = output1.filter($"datestr" < 
    end_period1).withColumn(
    sum_label, sum($"label").over(sliding_window_label)).drop(
    "label").withColumnRenamed(sum_label, "label")
val training1 = training_source.filter(col("label") === 1)
val training0 = training_source.filter(col("label") === 0).sample(false, 0.099685)
val training = training0.unionAll(training1)
val traindDF = training.select("label", 
    "features").toDF("label", "features")}

结果具有57M 0和214k 1(尽管行数大致相同)。 NA的{​​{1}}列中没有"label",并且类型仍为trainDF。然后xgboost失败:

double (nullable=true)

我可以根据需要添加日志。我的困惑是,使用窗口功能并且实际上不更改任何其他设置会导致XGB失败。我对此表示感谢。

1 个答案:

答案 0 :(得分:0)

事实证明,将表traindDF保存在配置单元中并将其重新加载到Spark中可以解决此问题:

traindDF.write.mode("overwrite").saveAsTable("database.tablename")

然后,您可以轻松地加载表格:

val traindDF = spark.sql("""select * from database.tablename""")

此技巧解决了问题。似乎火花开窗功能有点不稳定,将结果保存到配置单元表中使其可以工作。

一种更好的方法是在蜂巢中使用窗口函数而不是Spark。