PySaprk管道中的交叉验证过采样

时间:2019-11-15 20:47:17

标签: python pyspark cross-validation oversampling smote

我正在研究PySpark二进制分类管道,我想在其中执行过验证阶段的CrossValidation(我的数据集不平衡)。问题在于过采样阶段也会在测试数据集上执行。

管道:

(0, 1)

pipeline=Pipeline(stages=[cast_and_fill_na, smote, vec_assembler, rf]) 是我在转换测试数据集时要跳过的阶段。

我查看了spark文档和源代码,无法跳过PipelineModel中的阶段。我的解决方案是重写原始类的smote方法,以跳过监督阶段。 在我的源代码中适合管道时,这可以很好地工作。我用这个:

_transform

pipeline_model.__class__ = CustomPipelineModel 是从CustomPipelineModel继承并覆盖pyspark.ml.PipelineModel方法的类。

但是,由于CrossValidator使用PipelineModel类的原始实现,所以我不能使用自定义方法。

_transform

使用交叉验证器时,跳过过采样阶段的最佳方法是什么?

我开始研究evaluator = BinaryClassificationEvaluator(labelCol=target) crossval = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=10, parallelism=1) cvModel = crossval.fit(train_set) 的{​​{1}}方法的源代码,也考虑对其进行覆盖...第二种解决方案是对训练数据集执行过采样,但这会给交叉验证过程中的模型。

1 个答案:

答案 0 :(得分:0)

我想出了解决此问题的方法。 在我的SMOTEOversmapler类(smote阶段是它的一个实例)中,我添加了一个名为skip_transform的单元格,在实例化SMOTEOversmapler对象时将其设置为None。在_transform方法中,我将此属性设置为True。对_transform的下一个调用(处于测试阶段)将被跳过。这是一个代码段。

def __init__(self, ...):
    self.skip_transfrom = None
def _transform(self, df):
    if self.skip_transform:
         retrun df
    else:
         #Execute oversampling
         self.skip_transform = True