我有以下代码,用于训练模型并将其保存到hickle文件中(也可以是任何类型的文件)
from keras import Sequential
from keras.layers import Dense
from keras.models import load_model
import hickle as hkl
import numpy as np
class Model:
def __init__(self, data=None):
self.data = data
self.metrics = []
self.model = self.__build_model()
def __build_model(self):
model = Sequential()
model.add(Dense(4, activation='relu', input_shape=(3,)))
model.add(Dense(1, activation='relu'))
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
return model
def train(self, epochs):
self.model.fit(self.data[:, :-1], self.data[:,-1], epochs=epochs)
return self
def test(self, data):
self.metrics = self.model.evaluate(data[:, :-1], data[:, -1])
return self
def predict(self, input):
return self.model.predict(input)
def save(self, path):
data = {'metrics': self.metrics, 'k_model': self.model.get_config()}
hkl.dump(data, path, mode='w')
return self
def load(self, path):
data = hkl.load('model.hkl')
self.metrics = data['metrics']
self.model = Sequential.from_config(data['k_model'])
return self
def train():
train_data = np.random.rand(1000, 4)
test_data = np.random.rand(100, 4)
print("TRAINING, TESTING & SAVING..")
model = Model(train_data)\
.train(epochs=5)\
.test(test_data)\
.save('./model.hkl')
print('metrics: ', model.metrics)
conf = model.model.get_config()
print("type: ", type(conf))
print("length: ", len(conf))
if __name__ == '__main__':
train()
print('USING SAVED MODEL..')
model = Model()
model.load('./model.hkl')
print(model.metrics)
这将显示错误
TypeError: type object argument after ** must be a mapping, not PyContainer
怎么了?错误形式是keras还是来自ckle锁?
NB。这里我只是保存指标,但是它可以包含其他任何信息
谢谢。
答案 0 :(得分:0)
我在另一个问题上回答了类似的问题。这就是我总是保存keras模型的方式:
model.save('model.h5')
model_json = model.to_json()
with open("model.json", "w") as json_file:
json_file.write(model_json)
用于加载保存的模型:
model = load_model('model_w1.h5')
打印摘要:
model.summary()
要再次训练,您可以在加载后直接使用fit
。
如果要在某些应用程序中使用它,则首先使其成为全局(无需为每个预测重新加载)加载模型,然后使其权重:
def load_model():
global model
json_file = open('model.json', 'r')
model_json = json_file.read()
model = model_from_json(model_json)
model.load_weights("model.h5")
model._make_predict_function()