Assume the following PySpark custom transformer:
class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self, output_col):
        self.output_col = output_col
        self.feat_cols = None
        super(CustomTransformer, self).__init__()

    def _transform(self, df):
        self.feat_cols = get_match_columns(df, "ops")
        # Do something smart here with this feat_cols
        df = df.drop(*self.feat_cols)
        return df
where feat_cols is computed and set inside the _transform() method, and get_match_columns is a function that returns the column names matching some pattern. After the pipeline containing this transformer has been fit and applied, I need to access this parameter, for example:
pipeline = Pipeline(stages=[custom_transformer, assembler])
myPipe = pipeline.fit(data)
result = myPipe.transform(data)
using something like:
result.stages[0].getParam('feat_cols')
But, obviously, it doesn't work. I tried to follow this wrapper, coding this getter in my transformer:
def getFeatCols(self):
    return self.getOrDefault(self.feat_cols)
But I still can't recover the parameter (result.stages[0]._java_obj.getParam('feat_cols') doesn't work either).
Is there any way to solve this in PySpark?
Answer 0 (score: 0)
As @user10938362 pointed out in the comments, it is necessary to use a Param. The code that worked for me in this case is:
from pyspark.ml.param import Param
class CustomTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self, output_col):
        super(CustomTransformer, self).__init__()
        self.output_col = output_col
        self.feat_cols = Param(self, "feat_cols", "Feature columns")
        self._set(feat_cols=[])  # set or _set depends on the Spark version

    def _transform(self, df):
        self._set(feat_cols=get_match_columns(df, "ops"))
        # Do something smart here with this feat_cols
        df = df.drop(*self.getFeatCols())
        return df

    def getFeatCols(self):
        return self.getOrDefault("feat_cols")
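To see why this works without a Spark installation, here is a minimal stdlib-only sketch of the Param/getOrDefault mechanism: the `Param` and `Params` classes below are simplified stand-ins for `pyspark.ml.param`, not the real API, and the column-matching logic is a hypothetical substitute for get_match_columns.

```python
class Param:
    """Simplified stand-in for pyspark.ml.param.Param: a named slot
    registered on its parent instance."""
    def __init__(self, parent, name, doc):
        self.parent = parent
        self.name = name
        self.doc = doc


class Params:
    """Simplified stand-in for the PySpark Params mixin: values live in a
    param map keyed by Param objects, not in plain attributes."""
    def __init__(self):
        self._paramMap = {}

    def _resolveParam(self, param):
        # Accept either a Param object or its string name, as PySpark does.
        return param if isinstance(param, Param) else getattr(self, param)

    def _set(self, **kwargs):
        for name, value in kwargs.items():
            self._paramMap[getattr(self, name)] = value
        return self

    def getOrDefault(self, param):
        return self._paramMap[self._resolveParam(param)]


class CustomTransformer(Params):
    def __init__(self):
        super().__init__()
        # Register the param, then give it a default via _set.
        self.feat_cols = Param(self, "feat_cols", "Feature columns")
        self._set(feat_cols=[])

    def _transform(self, cols):
        # Hypothetical stand-in for get_match_columns(df, "ops"):
        # treat any column whose name starts with "ops" as a match.
        self._set(feat_cols=[c for c in cols if c.startswith("ops")])
        return [c for c in cols if not c.startswith("ops")]

    def getFeatCols(self):
        return self.getOrDefault("feat_cols")


t = CustomTransformer()
remaining = t._transform(["ops_a", "ops_b", "id"])
print(t.getFeatCols())  # ['ops_a', 'ops_b']
print(remaining)        # ['id']
```

The key point the sketch illustrates: the value set inside _transform survives because it is stored in the instance's param map and read back through getOrDefault, rather than shadowed by a plain attribute.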