循环中的Estimator.predict导致张量流中的内存泄漏

时间:2017-09-15 14:20:22

标签: tensorflow memory-leaks

当我使用张量流estimator.predict时,这发生在我身上。 比如,我通过以下方式从保存的模型中加载估算器:

estimator = tf.contrib.learn.Estimator(
    model_fn=model_fn, model_dir=FLAGS.model_dir, config=run_cfg)

get_input_fn()将返回input_fn,如下所示:

def get_input_fn(arg1, arg2):
    def input_fn():
        # do something
        #    ....
        return features, None
    return input_fn

然后,循环将用于预测来自file_iter的所有输入:

for idx, data in enumerate(file_iter):
    predicts = estimator.predict(input_fn=get_input_fn(data['query'],
                                                    data['responses']))

这会导致内存泄漏。每次调用estimator.predict后,内存会稍微增加,但永远不会下降。我使用objgraph来调试我的代码,并在每次调用estimator.predict后找到一些引用计数增加。

我真的不知道estimator.predict的洞察力。我想问题可能是因为我不止一次调用input_fn。我的张量流的版本是v1.2。

[UPDATE]

以下是objgraph的结果,左边是调用estimator.predict之前,mid是调用它之后,右边是另一个调用结果。如我所见,tuplelistdic在每次调用estimator.predict后稍微增加一点。我没有绘制参考图,因为我不熟悉它。

objgraph.show_most_common_types()    
tuple            146247 | tuple            180157   | tuple            213976
list             60745  | list             73107    | list             86111
dict             43412  | dict             50925    | dict             58437
function         28482  | function         28497    | function         28512
TensorShapeProto 9434   | TensorShapeProto 11793    | TensorShapeProto 14152
Dimension        8286   | Dimension        10360    | Dimension        12434
Operation        6098   | Operation        7625     | Operation        9152
AttrValue        6098   | NodeDef          7625     | NodeDef          9152
NodeDef          6098   | TensorShape      7575     | TensorShape      9092
TensorShape      6058   | Tensor           7575     | Tensor           9092

2 个答案:

答案 0 :(得分:1)

最后,我发现这是由调用太多tf.convert_to_tensor引起的,每次调用该函数都会在tensorflow图中生成一个新节点,这需要一些内存。

要解决此问题,只需使用tf.placeholder来提供数据。 此外,tensorflow v1.3添加了一个新方法tf.contrib.predictor来执行此操作。阅读https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/predictor

中的更多内容

答案 1 :(得分:0)

你可以发布你的objgraph结果吗?如果这是一个张量流问题或一般的python问题,它将有助于澄清。