使用predict
馈送模型时,在Keras模型tf.data.Dataset
函数上循环时遇到内存泄漏,并且性能下降,而在以numpy数组馈送模型时却没有。
有人知道导致此问题的原因和/或如何解决此问题的方法吗?
最小的可复制代码段(可复制/粘贴可运行):
import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
debug_time = time.time()
while True:
model.predict(x=ds, steps=1)
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()
结果:预测每次循环的时间大约为0.04s,在一两分钟之内达到0.5s,并且进程内存从数百MB持续增加到接近GB。
将tf.data.Dataset
换成等效的numpy数组,运行时间始终约为0.01s。
工作案例代码段(可复制/粘贴可运行):
import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
debug_time = time.time()
while True:
model.predict(x=np_data) # using numpy array directly
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()
相关讨论:
inter_op_paralellism
不会影响此处发布的结果。其他信息:
training_utils.py:1314
中注意到,Keras代码正在为每个预测调用创建一个迭代器。TF 1.14.0
答案 0 :(得分:0)
问题的根源似乎是Keras在每个predict
循环中创建数据集操作。请注意,在training_utils.py:1314
的每个预测循环中都会创建一个数据集迭代器。
可以通过传递迭代器来减轻问题的严重性,并且可以通过传递迭代器get_next()
张量来完全解决该问题。
我已在Tensorflow Github页面上发布了该问题:https://github.com/tensorflow/tensorflow/issues/30448
这是解决方案,此示例使用TF数据集在恒定时间内运行,只是无法传递数据集对象:
import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
it = tf.data.make_one_shot_iterator(ds)
tensor = it.get_next()
debug_time = time.time()
while True:
model.predict(x=tensor, steps=1)
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()