我正在研究PySpark二进制分类管道,我想在其中执行过验证阶段的CrossValidation(我的数据集不平衡)。问题在于过采样阶段也会在测试数据集上执行。
管道:
(0, 1)
pipeline=Pipeline(stages=[cast_and_fill_na, smote, vec_assembler, rf])
是我在转换测试数据集时要跳过的阶段。
我查看了spark文档和源代码,无法跳过PipelineModel中的阶段。我的解决方案是重写原始类的smote
方法,以跳过监督阶段。
在我的源代码中适合管道时,这可以很好地工作。我用这个:
_transform
pipeline_model.__class__ = CustomPipelineModel
是从CustomPipelineModel
继承并覆盖pyspark.ml.PipelineModel
方法的类。
但是,由于CrossValidator使用PipelineModel类的原始实现,所以我不能使用自定义方法。
_transform
使用交叉验证器时,跳过过采样阶段的最佳方法是什么?
我开始研究evaluator = BinaryClassificationEvaluator(labelCol=target)
crossval = CrossValidator(estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=10,
parallelism=1)
cvModel = crossval.fit(train_set)
的{{1}}方法的源代码,也考虑对其进行覆盖...第二种解决方案是对训练数据集执行过采样,但这会给交叉验证过程中的模型。
答案 0 :(得分:0)
我想出了解决此问题的方法。
在我的SMOTEOversmapler类(smote阶段是它的一个实例)中,我添加了一个名为skip_transform
的单元格,在实例化SMOTEOversmapler对象时将其设置为None。在_transform
方法中,我将此属性设置为True。对_transform
的下一个调用(处于测试阶段)将被跳过。这是一个代码段。
def __init__(self, ...):
self.skip_transfrom = None
def _transform(self, df):
if self.skip_transform:
retrun df
else:
#Execute oversampling
self.skip_transform = True