使用k.clear_session()和tf.reset_default_graph()清除后续模型之间的图形

时间:2019-02-01 10:06:45

标签: python tensorflow keras deep-learning

随后加载多个模型时,我似乎无法正确清除图形。

RunJob

只是在第一个模型加载后关闭Python中的程序。 如果删除上面的行,我可以加载后续模型,但随后会遇到内存泄漏。

k.clear_session()   
tf.reset_default_graph()

我对k.clear_session()和tf.reset_default()的使用不正确吗?

谢谢。

更新:

我尝试如下更改代码,但仍然遇到相同的问题:

>>> import keras
Using TensorFlow backend.
>>> keras.__version__
'2.2.4'
>>> import tensorflow as tf
>>> tf.__version__
'1.8.0'
>>> 





def evaluate_models(models_path_dir):
    models_paths = [os.path.join(models_path_dir, model) for model in os.listdir(models_path_dir) if model.endswith(".hdf5")]
    models_pairs = get_model_key(models_paths, global_model_keys)
    print(len(model_pairs)) #15
    for model_pair in models_pairs:
        model_path,model_key = model_pair
        img_height, img_width = 480, 480
        evaluate_validation_data(model_path, model_key)



def evaluate_validation_data(model_path,model_key):
    preprocess =  model_key
    valid_datagen = ImageDataGenerator(preprocessing_function = preprocess)
    valid_generator = valid_datagen.flow_from_directory(
    validation_data_dir,
    target_size = (img_height, img_width),
    batch_size = 30, 
    class_mode = 'categorical',
    shuffle = False)

    model = load_model(model_path)
    print("model path",model_path)
    print("image size", (img_height, img_width))
    print( model.evaluate_generator(valid_generator))
    k.clear_session()
    tf.reset_default_graph()

这是程序执行时发生的事情:

def evaluate_validation_data(model_path,model_key):
        preprocess =  model_key
        valid_datagen = ImageDataGenerator(preprocessing_function = preprocess)
        valid_generator = valid_datagen.flow_from_directory(
        validation_data_dir,
        target_size = (img_height, img_width),
        batch_size = 10, 
        class_mode = 'categorical',
        shuffle = False)

        model = load_model(model_path)
        print("model path",model_path)
        print("image size", (img_height, img_width))
        print( model.evaluate_generator(valid_generator))
        k.clear_session()
        #tf.reset_default_graph()




>>> import keras
Using TensorFlow backend.
>>> keras.__version__
'2.2.4'
>>> import tensorflow as tf
>>> tf.__version__
'1.8.0'
>>> 

然后关闭

2 个答案:

答案 0 :(得分:0)

松开tf.reset_default_graph(),您应该会很好。 至于内存泄漏,请确保您正在运行Keras 2.2.4(最好是tensorflow> = 1.10具有更好的keras集成),当依次加载多个模型时,我遇到了Keras 2.2.2崩溃的类似问题,并且在我消失后消失了已更新为Keras 2.2.4。

答案 1 :(得分:0)

似乎Keras高于2.2且tf 1.8时存在错误吗?

https://github.com/keras-team/keras/issues/10399

我需要将Keras降级到2.1吗?

编辑:

刚刚测试。 降级它2.1照顾该错误。