我们如何从先前保存的Keras模型绘制准确性和损失图?

时间:2018-06-23 18:38:50

标签: python-3.x keras

是否有一种方法可以从之前保存的CNN模型中绘制准确性和损失图? 还是只能在训练和评估模型期间绘制图形?

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(_NUM_CLASSES, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics= 
              ["accuracy"])
model.fit(x_train, y_train,
          batch_size=_BATCH_SIZE,
          epochs=_EPOCHS,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
model.save('model.h5')

2 个答案:

答案 0 :(得分:3)

在Keras中保存模型的可用选项均未包含培训历史记录,这正是您在此处要的。为了使此历史记录可用,您必须对培训代码进行一些琐碎的修改,以便分别保存。这是一个基于Keras MNIST example和仅3个训练时期的可复制示例:

hist = model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=3,
          verbose=1,
          validation_data=(x_test, y_test))

hist是Keras回调,它包含一个history字典,其中包含您要查找的指标:

hist.history
# result:
{'acc': [0.9234666666348775, 0.9744000000317892, 0.9805999999682109],
 'loss': [0.249011807457606, 0.08651042315363884, 0.06568188704450925],
 'val_acc': [0.9799, 0.9843, 0.9876],
 'val_loss': [0.06219216037504375, 0.04431889447008725, 0.03649089169385843]}

即每个训练时期(这里是3个)的训练和验证指标(这里是损失和准确性)。

现在使用Pickle保存此字典并根据需要将其还原是很简单的:

import pickle

# save:
f = open('history.pckl', 'wb')
pickle.dump(hist.history, f)
f.close()

# retrieve:    
f = open('history.pckl', 'rb')
history = pickle.load(f)
f.close()

在此处进行简单检查即可确认原始变量和检索到的变量确实相同:

hist.history == history
# True

答案 1 :(得分:0)

这取决于您保存模型的方式。

通常有两种情况,第一种是保存和加载整个模型(包括体系结构和权重):

$sql = "SELECT * from [table_name]
WHERE ... AND
Priority = '" . mssql_escape($urlPriority) . "' AND
etc AND
etc ";

第二个是仅保存权重:

from keras.models import load_model

model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
...
model = load_model('my_model.h5')

有关更多详细信息,请阅读Keras documentation