使用自定义转换器时如何正确腌制sklearn管道

时间:2019-09-11 11:36:03

标签: python scikit-learn persistence pipeline joblib

我正在尝试腌制一个sklearn机器学习模型,并将其加载到另一个项目中。该模型包装在具有编码,缩放等功能的管道中。当我想在管道中使用自写转换器执行更高级的任务时,问题就开始了。

假设我有2个项目:

  • train_project:在src.feature_extraction.transformers.py
  • 中具有自定义转换器
  • use_project:它在src中包含其他内容,或者根本没有src目录

如果在“ train_project”中使用joblib.dump()保存管道,然后在“ use_project”中通过joblib.load()加载管道,则找不到“ src.feature_extraction.transformers”之类的东西并抛出例外:

  

ModuleNotFoundError:没有名为“ src.feature_extraction”的模块

我还应该补充一点,我的初衷是简化模型的使用,以便程序员可以像加载其他模型一样加载模型,传递非常简单的,人类可读的功能,以及对功能进行所有“神奇”的预处理内部正在发生模型(例如梯度增强)。

我想到了在两个项目的根目录中创建/ dependencies / xxx_model /目录,并在其中存储所有需要的类和函数(将代码从“ train_project”复制到“ use_project”),因此项目的结构是相等的,并且转换器可以被加载。我发现此解决方案非常不雅致,因为它将强制使用该模型的任何项目的结构。

我想到了只是在“ use_project”中重新创建管道和所有变压器,并以某种方式从“ train_project”中加载变压器的拟合值。

最好的解决方案是如果转储的文件包含所有需要的信息并且不需要依赖,我真的为sklearn.Pipelines感到震惊-如果我无法加载,则适合管道的意义何在?以后对象?是的,如果我仅使用sklearn类,而不创建自定义类,但非自定义类没有所有必需的功能,那将是可行的。

示例代码:

train_project

src.feature_extraction.transformers.py

from sklearn.pipeline import TransformerMixin
class FilterOutBigValuesTransformer(TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        self.biggest_value = X.c1.max()
        return self

    def transform(self, X):
        return X.loc[X.c1 <= self.biggest_value]

train_project

main.py

from sklearn.externals import joblib
from sklearn.preprocessing import MinMaxScaler
from src.feature_extraction.transformers import FilterOutBigValuesTransformer

pipeline = Pipeline([
    ('filter', FilterOutBigValuesTransformer()),
    ('encode', MinMaxScaler()),
])
X=load_some_pandas_dataframe()
pipeline.fit(X)
joblib.dump(pipeline, 'path.x')

test_project

main.py

from sklearn.externals import joblib

pipeline = joblib.load('path.x')

预期结果是使用正确的转换方法正确加载了管道。

加载文件时,实际结果是异常。

4 个答案:

答案 0 :(得分:1)

我找到了一个非常简单的解决方案。假设您正在使用Jupyter笔记本进行培训:

  1. 在定义自定义转换器的地方创建一个.py文件,并将其导入Jupyter笔记本。

这是文件custom_transformer.py

from sklearn.pipeline import TransformerMixin

class FilterOutBigValuesTransformer(TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        self.biggest_value = X.c1.max()
        return self

    def transform(self, X):
        return X.loc[X.c1 <= self.biggest_value]
  1. 训练模型从.py文件导入此类并使用joblib保存。
import joblib
from custom_transformer import FilterOutBigValuesTransformer
from sklearn.externals import joblib
from sklearn.preprocessing import MinMaxScaler

pipeline = Pipeline([
    ('filter', FilterOutBigValuesTransformer()),
    ('encode', MinMaxScaler()),
])

X=load_some_pandas_dataframe()
pipeline.fit(X)

joblib.dump(pipeline, 'pipeline.pkl')
  1. 在不同的python脚本中加载.pkl文件时,您必须导入.py文件以使其工作:
import joblib
from utils import custom_transformer # decided to save it in a utils directory

pipeline = joblib.load('pipeline.pkl')

答案 1 :(得分:1)

根据我的研究,似乎最好的解决方案是创建一个包含受过训练的管道和所有文件的Python包。

然后,您可以将其pip安装在要使用它的项目中,并使用from <package name> import <pipeline name>导入管道。

答案 2 :(得分:0)

我创建了一种解决方法。我不认为这是对我的问题的完整答案,但即使如此,它仍然使我从问题中继续前进。

该变通办法起作用的条件:

I。管道仅需要两种变压器:

  1. sklearn变压器
  2. 自定义转换器,但仅具有以下类型的属性:
    • 号码
    • 字符串
    • 列表
    • dict

或这些的任意组合,例如包含字符串和数字的字典列表。通常重要的是属性可以json序列化。

II。管线步骤的名称必须唯一(即使存在管线嵌套)


简而言之,模型将与joblib转储文件,用于自定义转换器的json文件以及具有关于模型的其他信息的json文件一起存储为目录。

我创建了一个函数,该函数将通过管道的各个步骤并检查变压器的__module__属性。

如果在其中找到sklearn,则它将在步骤(步骤元组的第一个元素)中指定的名称下运行joblib.dump函数到某些选定的模型目录。

否则(在__module__中没有sklearn),它在等于步骤中指定的名称的键下将转换器的__dict__添加到result_dict。最后,我将json.dump的result_dict转储到名称为result_dict.json的模型目录中。

如果需要使用某些变压器,例如管道中有一个管道,您可以通过在函数的开头添加一些规则来递归地运行此函数,但是即使在主管道和子管道之间始终具有唯一的步骤/变压器名称也很重要。

如果创建模型管道还需要其他信息,则将其保存在model_info.json中。


然后,如果要加载模型以供使用: 您需要在目标项目中创建(无拟合)相同的管道。如果管道创建有些动态,并且您需要源项目中的信息,请从model_info.json加载它。

您可以复制用于序列化的功能,并且:

  • 用joblib.load语句替换所有joblib.dump,将__dict__从加载的对象分配给管道中已经存在的对象__dict __
  • 替换从__dict__添加到result_dict的所有位置,并从result_dict分配适当的值到对象__dict__(记住要事先从文件中加载result_dict)

运行此修改后的功能后,先前未拟合的管道应加载所有具有拟合效果的变压器属性,并且整个管道都可以进行预测。

对于该解决方案,我最不喜欢的是它需要目标项目中的管道代码,并且需要自定义转换器的所有属性都可以json序列化,但是我留给其他偶然发现类似问题的人使用,也许有人想出了更好的东西。

答案 3 :(得分:0)

您尝试过使用泡菜吗? https://github.com/cloudpipe/cloudpickle