我是Keras的新手,并且总体上是tensorflow,并且有问题。我正在使用一些损失函数(主要是binary_crossentropy和mean_squared_error)来计算预测后的损失。由于Keras仅接受它自己的变量类型,因此我正在创建一个变量类型并将其作为参数提供。这种情况是这样循环执行的(带有睡眠):
获取适当的数据->预测->计算丢失的数据->返回它。
由于我有多个遵循此模式的模型,因此我创建了张量流图和会话以防止碰撞(同样,在导出模型的权重时,我对单个图和会话也有问题,因此我必须为每个单个模型创建不同的模型)。
但是,现在内存不可控制地增加了,从两次MiB升级到700MiB。我知道Keras的clear_session()和gc.collect(),并且在每次迭代结束时都使用它们,但是问题仍然存在。在这里,我提供了项目中的代码片段,而不是实际的代码。我创建了单独的脚本来解决问题:
import tensorflow as tf
from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error
from time import time, sleep
import gc
from numpy.random import rand
from os import getpid
from psutil import Process
from csv import DictWriter
from keras import backend as K
this_process = Process(getpid())
graph = tf.Graph()
sess = tf.Session(graph=graph)
cnt = 0
max_c = 500
with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:
writer = DictWriter(file, fieldnames=['time', 'mem'])
writer.writeheader()
while cnt < max_c:
with graph.as_default(), sess.as_default():
y_true = K.variable(rand(36, 6))
y_pred = K.variable(rand(36, 6))
rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
val_loss = K.eval(mean_squared_error(y_true, y_pred))
writer.writerow({
'time': int(time()),
'mem': this_process.memory_info().rss
})
K.clear_session()
gc.collect()
cnt += 1
print(max_c - cnt)
sleep(0.1)
此外,我添加了内存使用情况图: Keras memory leak
感谢您的帮助。
答案 0 :(得分:1)
我刚刚删除了with
语句(可能是一些tf代码),但没有看到任何泄漏。我相信keras会话和tf默认会话之间是有区别的。因此,您没有通过K.clear_session()
清除正确的会话。可能使用tf.reset_default_graph()
也可以。
while True:
y_true = K.variable(rand(36, 6))
y_pred = K.variable(rand(36, 6))
val_loss = K.eval(binary_crossentropy(y_true, y_pred))
rec_loss = K.eval(mean_squared_error(y_true, y_pred))
K.clear_session()
gc.collect()
sleep(0.1)
答案 1 :(得分:0)
最后,我要做的是从K.variable()
语句中删除where
代码。这样,变量便成为默认图形的一部分,之后将由K.clear_session()
清除。