我在做什么
我正在训练并使用卷积神经元网络(CNN)进行图像分类,使用Keras和Tensorflow-gpu作为后端。
我正在使用的
- PyCharm社区2018.1.2
- Python 2.7和3.5(但不是一次两个)
- Ubuntu 16.04
- Keras 2.2.0
- Tensorflow-GPU 1.8.0作为后端
我想知道的事情
在许多代码中,我看到人们使用
from keras import backend as K
# Do some code, e.g. train and save model
K.clear_session()
或使用后删除模型:
del model
keras文档说明clear_session
:"销毁当前的TF图并创建一个新图。有助于避免旧模型/层的混乱。" - https://keras.io/backend/
这样做有什么意义,我也应该这样做?在加载或创建新模型时,我的模型无论如何都会被覆盖,所以为什么要这么麻烦?
答案 0 :(得分:9)
K.clear_session()在连续创建多个模型时(例如在超参数搜索或交叉验证期间)很有用。您训练的每个模型都会向图中添加节点(可能以数千为单位)。每当您(或Keras)调用tf.Session.run()或tf.Tensor.eval()时,TensorFlow都会执行整个图形,因此模型的训练速度将越来越慢,并且内存也可能用完。清除会话会删除以前模型中剩余的所有节点,从而释放内存并防止速度变慢。
答案 1 :(得分:2)
在交叉验证期间,我想运行number_of_replicates
倍(也称为重复)折叠以获取平均验证损失,以此作为与另一种算法比较的基础。因此,我需要针对两种单独的算法执行交叉验证,并且我有多个GPU可用,因此认为这不会成为问题。
不幸的是,我开始看到图层名称在自己的丢失日志中附加了_2
,_3
等内容。我还注意到,如果我通过在单个脚本中使用循环来依次执行重复项(也称为折叠),则会耗尽GPU的内存。
这个策略对我有用;我现在在Ubuntu lambda机器上的tmux
会话中已经连续运行了几个小时,有时会看到内存泄漏,但是它们被超时功能杀死了。它需要估计完成每个交叉验证折叠/复制所需的时间。在该代码下面的代码中,该数字为timeEstimateRequiredPerReplicate
(最好将通过循环的次数加倍,以防其中一半被杀死):
from multiprocessing import Process
# establish target for process workers
def machine():
import tensorflow as tf
from tensorflow.keras.backend import clear_session
from tensorflow.python.framework.ops import disable_eager_execution
import gc
clear_session()
disable_eager_execution()
nEpochs = 999 # set lower if not using tf.keras.callbacks.EarlyStopping in callbacks
callbacks = ... # establish early stopping, logging, etc. if desired
algorithm_model = ... # define layers, output(s), etc.
opt_algorithm = ... # choose your optimizer
loss_metric = ... # choose your loss function(s) (in a list for multiple outputs)
algorithm_model.compile(optimizer=opt_algorithm, loss=loss_metric)
trainData = ... # establish which data to train on (for this fold/replicate only)
validateData = ... # establish which data to validate on (same caveat as above)
algorithm_model.fit(
x=trainData,
steps_per_epoch=len(trainData),
validation_data=validateData,
validation_steps=len(validateData),
epochs=nEpochs,
callbacks=callbacks
)
gc.collect()
del algorithm_model
return
# establish main loop to start each process
def main_loop():
for replicate in range(replicatesDesired - replicatesCompleted):
print(
'\nStarting cross-validation replicate {} '.format(
replicate +
replicatesCompleted + 1
) +
'of {} desired:\n'.format(
replicatesDesired
)
)
p = Process(target=process_machine)
p.start()
flag = p.join(timeEstimateRequiredPerReplicate)
print('\n\nSubprocess exited with code {}.\n\n'.format(flag))
return
# enable running of this script from command line
if __name__ == "__main__":
main_loop()
答案 2 :(得分:1)
del将删除python中的变量,并且由于model是变量,所以del model将删除它,但是TF图将没有变化(TF是您的Keras后端)。也就是说,K.clear_session()将销毁当前的TF图并创建一个新的TF图。创建新模型似乎是一个独立的步骤,但请不要忘记后端:)