如何使用scikit-learn加载以前保存的模型并使用新的训练数据扩展模型

时间:2014-11-03 12:19:56

标签: python machine-learning scikit-learn

我使用scikit-learn我已经保存了一个逻辑回归模型,其中unigrams作为训练集1中的特征。是否可以加载此模型,然后使用新的数据实例从第二个扩展它训练集(训练集2)?如果是,怎么办呢?这样做的原因是因为我对每个训练集使用两种不同的方法(第一种方法涉及特征损坏/正规化,第二种方法涉及自我训练)。

为了清晰起见,我添加了一些简单的示例代码:

from sklearn.linear_model import LogisticRegression as log
from sklearn.feature_extraction.text import CountVectorizer as cv
import pickle

trainText1 # Training set 1 text instances    
trainLabel1 # Training set 1 labels 
trainText2 # Training set 2 text instances    
trainLabel2 # Training set 2 labels 

clf = log()
# Count vectorizer used by the logistic regression classifier 
vec = cv() 

# Fit count vectorizer with training text data from training set 1
vec.fit(trainText1) 

# Transforms text into vectors for training set1
train1Text1 = vec.transform(trainText1) 

# Fitting training set1 to the linear logistic regression classifier 
clf.fit(trainText1,trainLabel1)

# Saving logistic regression model from training set 1
modelFileSave = open('modelFromTrainingSet1', 'wb')
pickle.dump(clf, modelFileSave)
modelFileSave.close()  

# Loading logistic regression model from training set 1    
modelFileLoad = open('modelFromTrainingSet1', 'rb')
clf = pickle.load(modelFileLoad)

# I'm unsure how to continue from here....

1 个答案:

答案 0 :(得分:4)

LogisticRegression在内部使用不支持增量拟合的liblinear解算器。相反,您可以使用SGDClassifier(loss='log')作为partial_fit方法,尽管在实践中可以使用它。其他超参数是不同的。小心网格搜索其最佳值仔细。阅读SGDClassifier文档,了解这些超参数的含义。

CountVectorizer不支持增量拟合。您必须重新使用火车组#1上安装的矢量化器来转换#2。这意味着虽然#1中未见过的#2集中的任何令牌都将被完全忽略。这可能不是你所期望的。

为了缓解这种情况,您可以使用无状态的HashingVectorizer,代价是不知道功能的含义。请阅读the documentation了解详情。