如何"保存" Python中的IsolationForest模型?

时间:2017-11-07 11:34:05

标签: python machine-learning scikit-learn outliers

嘿我正在使用sklearn.ensemble.IsolationForest来预测我的数据的异常值。

是否可以将模型训练(适合)一次到我的干净数据,然后将其保存以便以后使用? 例如,为了保存模型的某些属性,所以下次再次调用fit函数来训练我的模型是不必要的。

例如,对于GMM,我会保存每个组件的weights_means_covs_,所以以后我不需要训练再次建模。

为了明确这一点,我将其用于在线欺诈检测,其中这个python脚本将被多次调用同一个"类别"数据,我不想在我需要执行预测或测试操作时训练模型。

提前致谢。

2 个答案:

答案 0 :(得分:0)

__getstate__估算器实现了一些方法,可以让您轻松保存估算器的相关训练属性。一些估算器本身实现了GMM方法,但是其他方法,例如def __getstate__(self): try: state = super(BaseEstimator, self).__getstate__() except AttributeError: state = self.__dict__.copy() if type(self).__module__.startswith('sklearn.'): return dict(state.items(), _sklearn_version=__version__) else: return state 只使用base implementation,它只是保存对象内部字典:

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

将模型保存到光盘的建议方法是使用pickle模块:

IsolationForest

但是,您应该保存其他数据,以便将来重新训练您的模型,或遭受可怕的后果(例如被锁定在旧版本的sklearn中)

来自documentation

  

为了使用未来版本重建类似的模型   scikit-learn,应该沿着pickled保存额外的元数据   型号:

     

训练数据,例如对不可变快照的引用

     

用于生成模型的python源代码

     

scikit-learn及其依赖项的版本

     

在训练数据上获得的交叉验证分数

对于依赖于用Cython编写的tree.pyx模块(例如joblib)的Ensemble估算器尤其如此,因为它创建了与实现的耦合,不保证在sklearn版本之间保持稳定。它在过去看到了倒退不兼容的变化。

如果您的模型变得非常大并且加载变得令人讨厌,您也可以使用效率更高的pickle。来自文档:

  

在scikit的特定情况下,使用它可能更有趣   joblib取代joblib.dumpjoblib.load& // Android SDK implementation( 'com.facebook.android:facebook-android-sdk:4.+' ){ exclude group: 'com.google.android.gms' } // Audience Network SDK. Only versions 4.6.0 and above are available implementation( 'com.facebook.android:audience-network-sdk:4.+'){ exclude group: 'com.google.android.gms' } // Account Kit implementation( 'com.facebook.android:account-kit-sdk:4.+'){ exclude group: 'com.google.android.gms' } ),这是   在内部携带大型numpy数组的对象上效率更高   通常情况下适合的scikit-learn估算器,但只能   pickle到磁盘而不是字符串:

答案 1 :(得分:0)

https://docs.python.org/2/library/pickle.html

使用Pickle库。

适合你的模特。

使用pickle.dump(obj, file[, protocol])

保存

使用pickle.load(file)

加载它

对异常值进行分类