Keras中每个model.fit()的CPU使用率和直到开始训练的时间

时间:2019-03-22 12:22:15

标签: python tensorflow keras

我用Keras API创建了一个LSTM。现在,当我尝试测试其中的不同值时(即学习率f.e.),我遇到了一个问题。每当我更改值并定义新模型时,模型花费的时间就会越来越长,直到训练开始时等待时间的CPU使用率为100%。我做错什么了吗,以至于旧的学习课程影响了新模型?

我的代码结构如下,在一个文件中,我调用一个具有不同值和许多迭代的评估,如下所示:

for i in range(0, 100):
    acc = model.create(xtrain, ytrain, hidden_units=hidden_size, batch_size=batch_size, learning_rate=learning_rate, l2_reg=l2_reg)

model是另一个文件。在这里,我使用传递的值来训练新模型,并传回精度以找到最佳的批次大小等。用于模型创建的代码如下:

def create(xtrain, ytrain, hidden_units, batch_size, learning_rate, l2_reg):
    # defining some layers from input to output
    # example: input = Input(shape=(20,)) ...

    # creating the model
    model = Model(inputs=[input], output=[outputs])
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['acc'])

    # calling model.fit
    es = EarlyStopping(monitor='val_loss', mode='min', patience=4, verbose=1)
    model.fit(xtrain, ytrain, epochs=100, batch_size=batch_size, validation_data=(some_xval_data, some_yval_data), callbacks=[es])

    ## In the end I evaluate the model on unseen data and return the accuracy
    loss, acc = model.evaluate(x_testdata, y_testdata, batch_size=batch_size)
    return acc

现在每次模型开始训练脚本打印时:

Epoch 1/100

在第一次评估调用时,模型立即开始训练,我看到每个步骤花费的时间。但是过了一段时间,在打印“ Epoch 1/100”之后,突然开始需要一段时间才能开始训练。而且时间从一个电话到另一个电话增加。在等待培训真正开始时,我可以观察到那段时间我的CPU使用率为100%。

那么在每次再次调用该方法时我做错了吗?在“创建”效果较旧的调用中是否存在某些过程?我只是希望旧的培训不会对我的代码结构产生影响?

2 个答案:

答案 0 :(得分:1)

感谢@Fedor Petrov和@desertnaut。

他们在另一个答案的注释中讨论了我必须调用函数clear_session

from keras.backend import clear_session

def create():
    # do all the model stuff
    # evaluate the model
    clear_session() 
    return

现在,我可以根据需要多次拨打create(),而不会发生内存泄漏。

答案 1 :(得分:0)

评估期间爆炸的内存使用量已经是一个已知问题。 Re-trained keras model evaluation leaks memory when called in a loop 通常在回调中定期进行评估时会注意到这一点。在您的情况下,您只需调用一次valuate(...)100次,这也足以观察到该问题。

自从我在云中训练了模型之后,我就通过增加实例的RAM来“解决”了类似的问题。

UPD。 这是我前一段时间的知识。以下卓有成效的讨论得出以下答案(如此处https://github.com/keras-team/keras/issues/2102所述):

   keras.backend.clear_session()

运行此命令将从tf.Graph中删除不必要的信息