Spark GBTClassifier始终以100%的准确度预测

时间:2018-04-11 21:55:36

标签: apache-spark machine-learning pyspark apache-spark-mllib apache-spark-ml

我使用SparkML GBTClassifier来训练二维分类问题的广泛数据集:

Xtrain.select(labelCol).groupBy(labelCol).count().orderBy(labelCol).show()
+-----+------+
|label| count|
+-----+------+
|    0|631608|
|    1| 18428|
+-----+------+

va = VectorAssembler(inputCols=col_header, outputCol="features")
tr = GBTClassifier(labelCol=labelCol, featuresCol="features", maxIter=30, maxDepth=5, seed=420)
pipeline = Pipeline(stages=[va, tr])
model = pipeline.fit(Xtrain)

分类器运行速度非常快(不寻常),并且100%准确地学习,超过测试集也可以100%准确地预测。当我打印

model.stages[1].featureImportances
SparseVector(29, {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0, 9: 0.0, 10: 0.0, 11: 0.0, 12: 0.0, 13: 0.0, 14: 0.0, 15: 0.0, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0, 20: 0.0, 21: 0.0, 22: 0.0, 23: 0.0, 24: 1.0, 25: 0.0, 26: 0.0, 27: 0.0, 28: 0.0})

我注意到我的DataFrame中的一个功能(在这种情况下为#24)为模型贡献了100%的权重。当我删除此字段并重新训练时,我看到相同的图片,唯一的区别是第二个字段现在正在为模型做出贡献,我得到100%的准确度。显然有些事情是不对的,它是什么?

1 个答案:

答案 0 :(得分:1)

非简并数据集中最常见的行为原因是数据泄漏。数据泄漏可以采取不同的形式,但考虑到

  

我的DataFrame中的一个功能(在这种情况下为#24)贡献了100%的权重

我们可以大大缩小范围:

  • 一个简单的编码错误 - 您在功能中包含了标签(或转换后的标签)。你应该仔细检查你的处理管道。
  • 原始数据包含用于派生标签或从标签派生的功能。您应该检查数据字典(如果存在)或其他可用来源,以确定应从模型中丢弃哪些功能(通常查找任何内容,您不会在原始数据中预期)。