Keras使用tf.data.Dataset而不是numpy数组预测循环内存泄漏

时间:2019-07-06 03:11:41

标签: python tensorflow keras

使用predict馈送模型时,在Keras模型tf.data.Dataset函数上循环时遇到内存泄漏,并且性能下降,而在以numpy数组馈送模型时却没有。

有人知道导致此问题的原因和/或如何解决此问题的方法吗?

最小的可复制代码段(可复制/粘贴可运行):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()

结果:预测每次循环的时间大约为0.04s,在一两分钟之内达到0.5s,并且进程内存从数百MB持续增加到接近GB。


tf.data.Dataset换成等效的numpy数组,运行时间始终约为0.01s。

工作案例代码段(可复制/粘贴可运行):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)

debug_time = time.time()
while True:
    model.predict(x=np_data)  # using numpy array directly
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()

相关讨论:


其他信息:

  • 通过传递迭代器而不是数据集对象,可以将性能下降的速度降低大约10倍。我在training_utils.py:1314中注意到,Keras代码正在为每个预测调用创建一个迭代器。

TF 1.14.0

1 个答案:

答案 0 :(得分:0)

问题的根源似乎是Keras在每个predict循环中创建数据集操作。请注意,在training_utils.py:1314的每个预测循环中都会创建一个数据集迭代器。

可以通过传递迭代器来减轻问题的严重性,并且可以通过传递迭代器get_next()张量来完全解决该问题。

我已在Tensorflow Github页面上发布了该问题:https://github.com/tensorflow/tensorflow/issues/30448

这是解决方案,此示例使用TF数据集在恒定时间内运行,只是无法传递数据集对象:

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
it = tf.data.make_one_shot_iterator(ds)
tensor = it.get_next()

debug_time = time.time()
while True:
    model.predict(x=tensor, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()