How do I save the Keras training history for cross-validation (in a loop)?

Asked: 2017-04-10 13:19:34

Tags: python deep-learning keras

  

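For cross-validation, how can I save the training history for each of the different training and validation sets? I thought writing the pickle in 'a' (append) mode would work, but it doesn't. If possible, could you also tell me how to save all of the models? Right now I can only save the last trained model with model.save(file).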

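import pickle
import numpy as np

historyfile = 'history.pickle'
f = open(historyfile,'w')  # create/truncate the file
f.close()
ind = 0
save = {}
cvscores = []  # per-fold evaluation scores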
for train, test in kfold.split(input,output):
    ind = ind+1
    #create model
    model = model_FCN()
    # fit the model
    history = model.fit(input[list(train)], output[list(train)], batch_size = 16, epochs = 100, verbose =1, validation_data =(input[list(test)],output[list(test)]))
    #save to file 
    try:
        f = open(historyfile,'a') ## appending mode??
        save['cv'+ str(ind)]= history.history
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
        f.close()
    except Exception as e:
        print('Unable to save data to', historyfile, ':', e)

    scores = model.evaluate(input[list(test)], output[list(test)], verbose=0)
    print("%s: %.2f" % (model.metrics_names[1], scores[1]))
    cvscores.append(scores[1])
    print("cross validation stage: " + str(ind))

print("%.2f (+/- %.2f)" % (np.mean(cvscores), np.std(cvscores)))

1 Answer:

Answer 0 (score: 0)

To save the model after every epoch for a given set of training and validation data, you can use a callback.

For example:

from keras.callbacks import ModelCheckpoint
import os

output_directory = '' # here should be path to output directory    
model_checkpoint = ModelCheckpoint(os.path.join(output_directory , 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'))
model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)],output[list(test)]),
          callbacks=[model_checkpoint])

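After each epoch, your model will be saved to a file. For more information about this callback, see the documentation (https://keras.io/callbacks/).

If you want to save the model trained on each fold, simply add model.save(file) inside the for loop: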
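When this callback is used inside the cross-validation loop from the question, every fold writes to the same filename pattern, so it becomes hard to tell which fold a checkpoint belongs to. One possible way around that, sketched below, is to put the fold index into the pattern. The helper name `fold_checkpoint_pattern`, the `fold_N.` prefix, and the `checkpoints` directory are assumptions for illustration, not part of Keras; the final `format` call only imitates how the callback fills in `epoch` and `val_loss`.

```python
import os


def fold_checkpoint_pattern(output_directory, fold):
    """Return a per-fold filename pattern for ModelCheckpoint (hypothetical helper).

    The fold index is filled in here; the {epoch} and {val_loss}
    placeholders are left untouched so the callback can fill them in later.
    """
    # Concatenate instead of calling .format() on the whole string,
    # which would fail on the still-unfilled {epoch} placeholder.
    filename = 'fold_{}.'.format(fold) + 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'
    return os.path.join(output_directory, filename)


pattern = fold_checkpoint_pattern('checkpoints', 2)
# Imitate what happens at the end of an epoch, with made-up values.
example_path = pattern.format(epoch=3, val_loss=0.4567)
print(example_path)
```

Inside the loop, the pattern would then be passed as `ModelCheckpoint(fold_checkpoint_pattern(output_directory, ind))`.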

model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)],output[list(test)]))
model.save(os.path.join(output_directory, 'fold_{}_model.hdf5'.format(ind)))

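Saving the history: you can save the history once, instead of appending it to the file on every loop iteration. After the for loop you should have a dictionary whose keys are the fold labels and whose values are the per-fold histories. Save that dictionary like this: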

# Write the whole dictionary once, in binary mode
with open(historyfile, 'wb') as f:
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
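Putting the pieces together, the following sketch collects one history per fold in a dictionary, writes it once after the loop, and reads it back. To stay runnable without Keras, `fake_fit_history` is a made-up stand-in for `model.fit(...).history`, and the fold count, epoch count, and temporary path are assumptions as well.

```python
import os
import pickle
import tempfile


def fake_fit_history(fold, epochs=3):
    """Hypothetical stand-in for model.fit(...).history: one list per metric."""
    return {
        'loss': [1.0 / (fold + epoch) for epoch in range(1, epochs + 1)],
        'val_loss': [1.5 / (fold + epoch) for epoch in range(1, epochs + 1)],
    }


historyfile = os.path.join(tempfile.mkdtemp(), 'history.pickle')

save = {}
for ind in range(1, 4):  # stands in for the kfold.split(...) loop
    save['cv' + str(ind)] = fake_fit_history(ind)

# Write the whole dictionary once, in binary mode.
with open(historyfile, 'wb') as f:
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)

# Later: load every fold's history back with a single call.
with open(historyfile, 'rb') as f:
    loaded = pickle.load(f)

print(sorted(loaded))
print(len(loaded['cv2']['val_loss']))
```

Note the binary modes ('wb' and 'rb'): pickle writes bytes, so under Python 3 a file opened in a text mode such as 'w' or 'a' makes pickle.dump fail.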