我已经通读了 TensorFlow 文档，但找不到用 SavedModelBuilder 类只保存验证精度最佳的模型的方法。我使用 tflearn 构建模型，下面是我尝试的做法：每个 epoch 单独调用一次 fit 方法，然后保存模型。但这样非常耗时。
# Per-epoch approach: train one epoch at a time via separate fit() calls and
# export a SavedModel after every epoch. The question reports this as slow,
# and it exports every epoch rather than only the best-scoring one.
for i in range(epoch):
    model.fit(trainX, trainY, n_epoch=1, validation_set=(testX, testY),
              show_metric=True, batch_size=8)
    # A fresh export directory per iteration, named after the loop index.
    builder = tf.saved_model.builder.SavedModelBuilder('/tmp/serving/model/' + str(i))
    # NOTE(review): 'TRAINING' is a custom tag string; a loader such as
    # TF Serving presumably looks for tag_constants.SERVING — confirm.
    builder.add_meta_graph_and_variables(model.session,
                                         ['TRAINING'],
                                         signature_def_map={
                                             'predict': prediction_sig
                                         })
    builder.save()
请问有没有更好的方法？
答案 0（得分：1）
我已经想出来了：可以通过回调（callback）来实现。谢谢。
class SaveModelCallback(tflearn.callbacks.Callback):
    """tflearn callback that exports a SavedModel whenever validation accuracy improves.

    At the end of every epoch the training accuracy is recorded. If the
    validation accuracy exceeds both ``accuracy_threshold`` and the best
    validation accuracy exported so far, the model's session is exported with
    ``SavedModelBuilder`` to ``<export_base>/<epoch>``.

    Args:
        accuracy_threshold: Minimum validation accuracy required before any
            export happens.
        export_base: Parent directory of the per-epoch export directories.
            Defaults to ``'/tmp/serving/model/'``.
        model: tflearn model whose ``session`` is exported. When ``None``
            (the default) the module-level ``model`` is used.
        signature_def_map: SignatureDef map passed to
            ``add_meta_graph_and_variables``. When ``None`` (the default) it
            is built from the module-level ``prediction_sig`` and
            ``classification_signature``.
    """

    def __init__(self, accuracy_threshold, export_base='/tmp/serving/model/',
                 model=None, signature_def_map=None):
        super(SaveModelCallback, self).__init__()
        self.accuracy_threshold = accuracy_threshold
        # Training accuracy (training_state.global_acc) recorded per epoch.
        self.accuracy = []
        # Best validation accuracy exported so far; starts at -1 so the first
        # epoch that clears the threshold always triggers an export.
        self.max_accuracy = -1
        self.export_base = export_base
        self._model = model
        self._signature_def_map = signature_def_map

    def on_epoch_end(self, training_state):
        """Record training accuracy and export if validation accuracy improved.

        Args:
            training_state: tflearn training state for the finished epoch;
                ``global_acc``, ``val_acc`` and ``epoch`` are read.
        """
        self.accuracy.append(training_state.global_acc)
        val_acc = training_state.val_acc
        # val_acc may be None (e.g. fit() called without a validation set);
        # comparing None with a number raises TypeError on Python 3.
        if val_acc is None:
            return
        if val_acc > self.accuracy_threshold and val_acc > self.max_accuracy:
            self.max_accuracy = val_acc
            self.save_model(training_state.epoch)

    def save_model(self, epoch):
        """Export the current session as a SavedModel under ``<export_base>/<epoch>``.

        Args:
            epoch: Epoch number; used as the export subdirectory name.
        """
        import os

        source_model = self._model if self._model is not None else model
        signature_def_map = self._signature_def_map
        if signature_def_map is None:
            # Fall back to the module-level signatures.
            signature_def_map = {
                'predict': prediction_sig,
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    classification_signature,
            }
        # One directory per epoch so successive best models do not collide.
        export_dir = os.path.join(self.export_base, str(epoch))
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            source_model.session,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map)
        builder.save()
        # Report only once the export has actually been written.
        print('saved epoch ' + str(epoch))
# Train once for 200 epochs; the callback exports a SavedModel whenever the
# validation accuracy beats 0.8 and every previously exported epoch.
callback = SaveModelCallback(accuracy_threshold=0.8)
model.fit(
    trainX,
    trainY,
    n_epoch=200,
    batch_size=8,
    validation_set=(testX, testY),
    show_metric=True,
    callbacks=callback,
)