Keras variable()内存泄漏

时间:2018-11-08 09:23:59

标签: python tensorflow memory-leaks keras

我是Keras的新手,并且总体上是tensorflow,并且有问题。我正在使用一些损失函数(主要是binary_crossentropy和mean_squared_error)来计算预测后的损失。由于Keras仅接受它自己的变量类型,因此我正在创建一个变量类型并将其作为参数提供。这种情况是这样循环执行的(带有睡眠):

获取适当的数据->预测->计算丢失的数据->返回它。

由于我有多个遵循此模式的模型,因此我创建了张量流图和会话以防止碰撞(同样,在导出模型的权重时,我对单个图和会话也有问题,因此我必须为每个单个模型创建不同的模型)。

但是,现在内存不可控制地增加了,从两次MiB升级到700MiB。我知道Keras的clear_session()和gc.collect(),并且在每次迭代结束时都使用它们,但是问题仍然存在。在这里,我提供了项目中的代码片段,而不是实际的代码。我创建了单独的脚本来解决问题:

import tensorflow as tf

from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error

from time import time, sleep
import gc
from numpy.random import rand

from os import getpid
from psutil import Process

from csv import DictWriter
from keras import backend as K

this_process = Process(getpid())

graph = tf.Graph()
sess = tf.Session(graph=graph)

cnt = 0
max_c = 500

with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:
    writer = DictWriter(file, fieldnames=['time', 'mem'])
    writer.writeheader()

    while cnt < max_c:  
        with graph.as_default(), sess.as_default():         
            y_true = K.variable(rand(36, 6))
            y_pred = K.variable(rand(36, 6))

            rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
            val_loss = K.eval(mean_squared_error(y_true, y_pred))

            writer.writerow({
                'time': int(time()),
                'mem': this_process.memory_info().rss
            })

        K.clear_session()
        gc.collect()

        cnt += 1
        print(max_c - cnt)
        sleep(0.1)

此外,我添加了内存使用情况图: Keras memory leak

感谢您的帮助。

2 个答案:

答案 0 :(得分:1)

我刚刚删除了with语句(可能是一些tf代码),但没有看到任何泄漏。我相信keras会话和tf默认会话之间是有区别的。因此,您没有通过K.clear_session()清除正确的会话。可能使用tf.reset_default_graph()也可以。

while True: 
    y_true = K.variable(rand(36, 6))
    y_pred = K.variable(rand(36, 6))

    val_loss = K.eval(binary_crossentropy(y_true, y_pred))
    rec_loss = K.eval(mean_squared_error(y_true, y_pred))

    K.clear_session()
    gc.collect()

    sleep(0.1)

答案 1 :(得分:0)

最后,我要做的是从K.variable()语句中删除where代码。这样,变量便成为默认图形的一部分,之后将由K.clear_session()清除。