更改sklearn管道的参数

时间:2020-09-18 23:26:51

标签: python scikit-learn pipeline

我有一个Pipeline,我已经使用pickle对其进行了培训和保存。它包含以下步骤:

Pipeline(steps=[('preprocessing': Preprocessor()),
('my_transformer': my_transformer()), 
('model': XGBClassifier())
])

我想在执行my_transformer()时记录一些信息,但是仅当我预测概率时,即我运行pipeline.predict_proba()时,才记录一些信息。我不希望在运行pipeline.predict()时执行日志记录行。

my_transformer()如下所示:

class my_transformer(BaseEstimator, TransformerMixin):
  def __init__(self, flag_log=False):
    self.flag_log = flag_log

  def transform(self, features):
    #apply transformations
    if self.flag_log:
      logger.info("log probabilities")

我要做的是根据是否要记录信息来修改flag_log的值。基本上,我想要这样的东西:

pipeline.set_params(my_transformer__flag_log=True)
probabilities = pipeline.predict_proba(features)
pipeline.set_params(my_transformer__flag_log=False)
predictions = pipeline.predict(features)

我尝试了上面的代码,但是它不起作用,flag_log的值不变。 还有其他解决方案吗?

1 个答案:

答案 0 :(得分:0)

要修改flag_log的值,只需运行:

pipeline.set_params(my_transformer__flag_log=False)

那应该对我有用。否则,如果您确定它不能那样工作,请提供您的代码minimal, reproducible example,以便其他人可以重现您的问题(我会研究它,然后更新我的答案)。