如何腌制Keras模型?

时间:2018-01-17 07:23:54

标签: python machine-learning keras pickle

官方文件声明“不建议使用pickle或cPickle保存Keras型号。”

然而,我对酸洗Keras模型的需求源于使用sklearn的RandomizedSearchCV(或任何其他超参数优化器)的超参数优化。将结果保存到文件中至关重要,因为脚本可以在分离的会话中远程执行等。

基本上,我想:

getWindow().getDecorView().setSystemUiVisibility(
        View.SYSTEM_UI_FLAG_LAYOUT_STABLE
        | View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN);

3 个答案:

答案 0 :(得分:5)

到目前为止,Keras型号是可腌制的。但是我们仍然建议使用model.save()将模型保存到磁盘。

答案 1 :(得分:3)

这就像魅力http://zachmoshe.com/2017/04/03/pickling-keras-models.html

import types
import tempfile
import keras.models

def make_keras_picklable():
    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            keras.models.save_model(self, fd.name, overwrite=True)
            model_str = fd.read()
        d = { 'model_str': model_str }
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = keras.models.load_model(fd.name)
        self.__dict__ = model.__dict__


    cls = keras.models.Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__

make_keras_picklable()

PS。我有一些问题,由于我的model.to_json()由于循环引用而引发TypeError('Not JSON Serializable:', obj),并且上面的代码已经以某种方式吞下了这个错误,因此导致pickle函数永远运行。

答案 2 :(得分:3)

分别使用get_weights和set_weights保存和加载模型。

看看这个链接:Unable to save DataFrame to HDF5 ("object header message is too large")

#for heavy model architectures, .h5 file is unsupported.
weigh= model.get_weights();    pklfile= "D:/modelweights.pkl"
try:
    fpkl= open(pklfile, 'wb')    #Python 3     
    pickle.dump(weigh, fpkl, protocol= pickle.HIGHEST_PROTOCOL)
    fpkl.close()
except:
    fpkl= open(pklfile, 'w')    #Python 2      
    pickle.dump(weigh, fpkl, protocol= pickle.HIGHEST_PROTOCOL)
    fpkl.close()