Getting a parameter from a fitted custom PySpark transformer

Asked: 2019-03-06 19:43:37

Tags: apache-spark pyspark transformer

Suppose the following custom PySpark transformer:

class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):

    def __init__(self, output_col):
        self.output_col = output_col
        self.feat_cols = None
        super(CustomTransformer, self).__init__()

    def _transform(self, df):

        self.feat_cols = get_match_columns(df, "ops")
        # Do something smart here with this feat_cols
        df = df.drop(*self.feat_cols)

        return df

where feat_cols is computed and set inside the _transform() method, and get_match_columns is a function that returns the column names matching some pattern. After the pipeline containing this transformer has been fitted and applied, I need to access this parameter, for example:

pipeline = Pipeline(stages=[custom_transformer, assembler])
myPipe = pipeline.fit(data)
result = myPipe.transform(data)

with some method such as:

result.stages[0].getParam('feat_cols')

However, obviously, that does not work. I tried to follow this wrapper and code this getter in my transformer:

def getFeatCols(self):
        return self.getOrDefault(self.feat_cols)

But I still cannot recover the parameter (result.stages[0]._java_obj.getParam('feat_cols') does not work either).

Is there any way to work around this in PySpark?

1 answer:

Answer 0 (score: 0)

As @user10938362 pointed out in the comments, it is necessary to use Param. In this case, the code that worked for me is:

from pyspark.ml.param import Param

class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):

    def __init__(self, output_col):
        super(CustomTransformer, self).__init__()
        self.output_col = output_col
        self.feat_cols = Param(self, "feat_cols", "Feature columns")
        self._set(feat_cols=[])  # set or _set, depending on the Spark version

    def _transform(self, df):
        self._set(feat_cols=get_match_columns(df, "ops"))
        # Do something smart here with this feat_cols
        df = df.drop(*self.getFeatCols())

        return df

    def getFeatCols(self):
        return self.getOrDefault("feat_cols")