How to save a custom transformer in sklearn?

Time: 2017-09-06 14:25:20

Tags: python session scikit-learn pickle

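I cannot load an instance of a custom transformer that was saved with either sklearn.externals.joblib.dump or pickle.dump, because the original definition of the custom transformer is missing from the current python session.

Suppose that in one python session I define, create and save a custom transformer. It can also be loaded again in that same session: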

from sklearn.base import TransformerMixin
from sklearn.base import BaseEstimator
from sklearn.externals import joblib

class CustomTransformer(BaseEstimator, TransformerMixin):
    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X


custom_transformer = CustomTransformer()    
joblib.dump(custom_transformer, 'custom_transformer.pkl')

loaded_custom_transformer = joblib.load('custom_transformer.pkl')

Opening a new python session and loading from 'custom_transformer.pkl' with
from sklearn.externals import joblib

joblib.load('custom_transformer.pkl')

raises the following exception:

AttributeError: module '__main__' has no attribute 'CustomTransformer'
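Why this happens, as a minimal sketch: pickle (and therefore joblib) stores an instance's class only as a reference, i.e. the module name plus the class name, never the class's code. The example assumes a throwaway module with the hypothetical name demo_session_module in place of __main__, and a plain stand-in class instead of the sklearn subclass, so that it runs with the standard library alone. Deleting the class from the module imitates a new session in which the module exists but no longer contains CustomTransformer.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Plain stand-in for the question's CustomTransformer (no sklearn needed).
MODULE_SOURCE = """
class CustomTransformer:
    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X
"""


def reproduce_missing_class_error() -> tuple[bytes, str]:
    """Pickle an instance, remove its class, then try to unpickle it."""
    # Write the hypothetical module demo_session_module to a temp directory.
    tmp_dir = Path(tempfile.mkdtemp())
    (tmp_dir / "demo_session_module.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, str(tmp_dir))
    importlib.invalidate_caches()
    sys.modules.pop("demo_session_module", None)  # force a fresh import
    module = importlib.import_module("demo_session_module")

    # The pickle records only "demo_session_module" and "CustomTransformer".
    data = pickle.dumps(module.CustomTransformer())

    # Imitate the new session: the module is there, the class is not.
    del module.CustomTransformer
    try:
        pickle.loads(data)
    except AttributeError as exc:
        return data, str(exc)
    raise RuntimeError("expected unpickling to fail")
```

Calling reproduce_missing_class_error() returns the raw pickle bytes, in which both names appear as plain text, and an AttributeError message of the same kind as the one above, with demo_session_module in place of __main__.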

The same thing happens if joblib is replaced with pickle. Saving the custom transformer in one session with
import pickle

with open('custom_transformer_pickle.pkl', 'wb') as f:
    pickle.dump(custom_transformer, f, pickle.HIGHEST_PROTOCOL)

and loading it in another one with

import pickle

with open('custom_transformer_pickle.pkl', 'rb') as f:
    loaded_custom_transformer_pickle = pickle.load(f)

raises the same exception.

If CustomTransformer is replaced with sklearn.preprocessing.StandardScaler in the code above, the saved instance can be loaded in a new python session without problems.
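For contrast, a standard-library analogue of the StandardScaler case, assuming fractions.Fraction as a stand-in for a class that ships with an installed package. A fresh interpreter started with subprocess plays the role of the new session; its loader script never imports fractions explicitly, yet loading works, because pickle imports the module named in the pickle by itself.

```python
import pickle
import subprocess
import sys
from fractions import Fraction

# Runs in a brand-new interpreter: read pickle bytes from stdin, print repr.
LOADER = "import pickle, sys; print(repr(pickle.loads(sys.stdin.buffer.read())))"


def load_in_fresh_interpreter(obj) -> str:
    """Pickle obj in this process, unpickle it in a separate Python process."""
    data = pickle.dumps(obj)
    result = subprocess.run(
        [sys.executable, "-c", LOADER],
        input=data,
        capture_output=True,
        check=True,
    )
    return result.stdout.decode().strip()
```

load_in_fresh_interpreter(Fraction(1, 3)) returns 'Fraction(1, 3)': the class is found because its module is importable in any session, which is exactly what a class defined in __main__ lacks.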

Is it possible to save a custom transformer and load it later somewhere else?

3 answers:

Answer 0 (score: 4)

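sklearn.preprocessing.StandardScaler works because its class definition is available in the installed sklearn package: the pickle only stores a reference to the class, and joblib looks that definition up in the package when loading the pickle.

You have to make the CustomTransformer class available in the new session, either by redefining it there or by importing it.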
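A minimal sketch of the "import it" option, under stated assumptions: the class lives in a module with the hypothetical name my_transformers.py (here a plain stand-in class, so only the standard library is needed), and a second Python process started with subprocess plays the new session. Because the pickle only names my_transformers.CustomTransformer, the second process needs nothing except the ability to import that module.

```python
import importlib
import pickle
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical module name; in a real project this would be your own file,
# e.g. my_transformers.py containing the BaseEstimator/TransformerMixin class.
MODULE_NAME = "my_transformers"
MODULE_SOURCE = """
class CustomTransformer:
    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X
"""

# The "new session": it never defines CustomTransformer. It only puts the
# project directory on sys.path so that pickle can import my_transformers.
LOADER = """
import pickle, sys
sys.path.insert(0, sys.argv[2])
with open(sys.argv[1], "rb") as f:
    obj = pickle.load(f)
print(type(obj).__module__ + "." + type(obj).__qualname__)
"""


def save_here_load_elsewhere() -> str:
    """Save an instance in this process, load it in a fresh one."""
    project_dir = Path(tempfile.mkdtemp())
    (project_dir / f"{MODULE_NAME}.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, str(project_dir))
    importlib.invalidate_caches()
    sys.modules.pop(MODULE_NAME, None)
    module = importlib.import_module(MODULE_NAME)

    # Session 1: create and save the transformer (imported, not in __main__).
    pkl_path = project_dir / "custom_transformer.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(module.CustomTransformer(), f)

    # Session 2: a separate Python process loads it.
    result = subprocess.run(
        [sys.executable, "-c", LOADER, str(pkl_path), str(project_dir)],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()
```

save_here_load_elsewhere() should return 'my_transformers.CustomTransformer'. The same requirement applies to joblib.load, since joblib serializes objects through pickle; in practice this means keeping the class in a module or package that is importable (on the path or installed) wherever the model is loaded.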

Answer 1 (score: 0)

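It worked for me when I passed the transform function to sklearn.preprocessing.FunctionTransformer() and saved the model as a '.pk' file with dill.dump() (loading it back with dill.load()).

Note: my classifier had the transform function included in an sklearn pipeline.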

Answer 2 (score: 0)

I did not use sklearn.externals.joblib but just the standalone joblib module, and it works:

Example:

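import pandas as pd
from nltk.corpus import stopwords
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

# NGRAM_RANGE, TOKEN_MODE and MIN_DOCUMENT_FREQUENCY are constants defined
# elsewhere in the answer's original code.

class CustomNgramVectorize(BaseEstimator, TransformerMixin):
    """Vectorizes texts as n-gram vectors"""
    def __init__(self, text, reduce=True):
        self.text = text  # store every __init__ argument so get_params() works
        # Create keyword arguments to pass to the 'tf-idf' vectorizer.
        kwargs = {
                'ngram_range': NGRAM_RANGE,  # Use 1-grams + 2-grams.
                'dtype': 'float32',  # TfidfVectorizer expects a float dtype
                'strip_accents': 'unicode',
                'decode_error': 'replace',
                'max_features' : 1000, #limit number of words
                'sublinear_tf': True, # Apply sublinear tf scaling
                'stop_words' : stopwords.words('french'),# drop french stopwords
                'analyzer': TOKEN_MODE,  # Split text into word tokens.
                'min_df': MIN_DOCUMENT_FREQUENCY,
        }
        self.tfidf_vectorizer = TfidfVectorizer(**kwargs)
        self.reduce = reduce
        if self.reduce:
            self.svd = TruncatedSVD(n_components=25, n_iter=25, random_state=12)

    def fit(self, X, y=None):
        X_tfidf = self.tfidf_vectorizer.fit_transform(X)
        if self.reduce:
            # Learn the SVD projection once, on the training data
            self.svd.fit(X_tfidf)
        return self  # sklearn convention: fit returns the estimator

    def transform(self, X, y=None):
        X_tfidf = self.tfidf_vectorizer.transform(X)
        if self.reduce:
            # Reuse the projection fitted in fit() instead of refitting here
            return self.svd.transform(X_tfidf)
        # convert to dataframe, one column per vocabulary term
        return pd.DataFrame(X_tfidf.toarray(), columns=sorted(self.tfidf_vectorizer.vocabulary_))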

Then save it with the joblib.dump function:

# persist model (vectorizer is an instance of CustomNgramVectorize)
import joblib
joblib.dump(vectorizer, 'custom_tfidf_vectorizer.joblib')

Later, retrieve it with the joblib.load function:

import joblib

# train is the answer's pandas DataFrame containing the text column
var = 'route_of_administration'
v = joblib.load('custom_tfidf_vectorizer.joblib')
v.fit(train[var])
X_df = v.transform(train[var])
