我需要什么K.clear_session()和del model(Keras with Tensorflow-gpu)?

时间:2018-06-17 08:49:39

标签: python tensorflow memory-management keras

我在做什么
我正在训练并使用卷积神经元网络(CNN)进行图像分类,使用Keras和Tensorflow-gpu作为后端。

我正在使用的
- PyCharm社区2018.1.2
- Python 2.7和3.5(但不是一次两个)
- Ubuntu 16.04
- Keras 2.2.0
- Tensorflow-GPU 1.8.0作为后端

我想知道的事情
在许多代码中,我看到人们使用

from keras import backend as K 

# Do some code, e.g. train and save model

K.clear_session()

或使用后删除模型:

del model

keras文档说明clear_session:"销毁当前的TF图并创建一个新图。有助于避免旧模型/层的混乱。" - https://keras.io/backend/

这样做有什么意义,我也应该这样做?在加载或创建新模型时,我的模型无论如何都会被覆盖,所以为什么要这么麻烦?

3 个答案:

答案 0 :(得分:9)

K.clear_session()在连续创建多个模型时(例如在超参数搜索或交叉验证期间)很有用。您训练的每个模型都会向图中添加节点(可能以数千为单位)。每当您(或Keras)调用tf.Session.run()或tf.Tensor.eval()时,TensorFlow都会执行整个图形,因此模型的训练速度将越来越慢,并且内存也可能用完。清除会话会删除以前模型中剩余的所有节点,从而释放内存并防止速度变慢。

答案 1 :(得分:2)

在交叉验证期间,我想运行number_of_replicates倍(也称为重复)折叠以获取平均验证损失,以此作为与另一种算法比较的基础。因此,我需要针对两种单独的算法执行交叉验证,并且我有多个GPU可用,因此认为这不会成为问题。

不幸的是,我开始看到图层名称在自己的丢失日志中附加了_2_3等内容。我还注意到,如果我通过在单个脚本中使用循环来依次执行重复项(也称为折叠),则会耗尽GPU的内存。

这个策略对我有用;我现在在Ubuntu lambda机器上的tmux会话中已经连续运行了几个小时,有时会看到内存泄漏,但是它们被超时功能杀死了。它需要估计完成每个交叉验证折叠/复制所需的时间。在该代码下面的代码中,该数字为timeEstimateRequiredPerReplicate(最好将通过循环的次数加倍,以防其中一半被杀死):

from multiprocessing import Process

# establish target for process workers
def machine():
    import tensorflow as tf
    from tensorflow.keras.backend import clear_session

    from tensorflow.python.framework.ops import disable_eager_execution
    import gc

    clear_session()

    disable_eager_execution()  
    nEpochs = 999 # set lower if not using tf.keras.callbacks.EarlyStopping in callbacks
    callbacks = ... # establish early stopping, logging, etc. if desired

    algorithm_model = ... # define layers, output(s), etc.
    opt_algorithm = ... # choose your optimizer
    loss_metric = ... # choose your loss function(s) (in a list for multiple outputs)
    algorithm_model.compile(optimizer=opt_algorithm, loss=loss_metric)

    trainData = ... # establish which data to train on (for this fold/replicate only)
    validateData = ... # establish which data to validate on (same caveat as above)
    algorithm_model.fit(
        x=trainData,
        steps_per_epoch=len(trainData),
        validation_data=validateData,
        validation_steps=len(validateData),
        epochs=nEpochs,
        callbacks=callbacks
    )

    gc.collect()
    del algorithm_model

    return


# establish main loop to start each process
def main_loop():
    for replicate in range(replicatesDesired - replicatesCompleted):
        print(
            '\nStarting cross-validation replicate {} '.format(
                replicate +
                replicatesCompleted + 1
            ) +
            'of {} desired:\n'.format(
                replicatesDesired
            )
        )
        p = Process(target=process_machine)
        p.start()
        flag = p.join(timeEstimateRequiredPerReplicate)
        print('\n\nSubprocess exited with code {}.\n\n'.format(flag))
    return


# enable running of this script from command line
if __name__ == "__main__":
    main_loop()

答案 2 :(得分:1)

del将删除python中的变量,并且由于model是变量,所以del model将删除它,但是TF图将没有变化(TF是您的Keras后端)。也就是说,K.clear_session()将销毁当前的TF图并创建一个新的TF图。创建新模型似乎是一个独立的步骤,但请不要忘记后端:)