GBTClassifier如何处理用于二进制分类的不平衡数据?

时间:2019-05-17 13:25:51

标签: apache-spark apache-spark-ml

我想使用GBTClassifier对不平衡的数据集执行二进制分类。 我没有看到spark documentation允许这样做的任何选择。

有人通过指定我们的数据不平衡这一事实对使用GBTClassifier有想法吗?

谢谢

注意:我使用的是Spark 2.3.2

1 个答案:

答案 0 :(得分:0)

这是我的幼稚解决方案:随机降低多数类的采样率。 这种解决方案的缺点是信息丢失,并且不适用于小型数据集。

val resampledTrainDF = {

    val positiveLabel = "1"
    val trainDF_positives = trainDF.where(F.col(label) === positiveLabel)
    val trainDF_negatives = trainDF.where(F.col(label) =!= positiveLabel)

    val withReplacement = trainDF_positives.count >= trainDF_negatives.count

    if (withReplacement) {
        // downsampling positives
        val sampSize = math.round(  (1.0 * trainDF_negatives.count / trainDF_positives.count) * 1000) / 1000.0
        println("Downsampling Positives by " + (1 - sampSize)*100 + " %")
        trainDF_positives.sample(false, sampSize).union(trainDF_negatives)
    } else { 
        //downsampling negatives
        val sampSize = math.round(  (1.0 * trainDF_positives.count / trainDF_negatives.count) * 1000) / 1000.0
        println("Downsampling Negatives by " + (1 - sampSize)*100 +  "%")
        trainDF_negatives.sample(false, sampSize).union(trainDF_positives)
    }

}