如何在python中加载冻结张量流图时减少内存使用?

时间:2018-06-15 05:33:22

标签: python tensorflow

我有一个预先训练过的Tensorflow模型已被冻结(所有变量都转换为常量)并作为序列化GraphDef保存到磁盘。我目前正在使用以下Python代码从磁盘读取模型:

graph = tf.Graph()
with graph.as_default():
    with open(filename, 'rb') as f:
        graph_def = tf.GraphDef.FromString(f.read())
    tf.import_graph_def(graph_def, name='')
sess = tf.Session(graph=graph)

一切都按预期工作,但内存使用率高于要求。 GraphDef中有大约400MB的权重,因此保留多个数据副本会浪费大量的RAM。

查看tf.import_graph_def的源代码,我发现了以下内容:

with graph._lock:  # pylint: disable=protected-access
      with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:

换句话说,刚刚从字符串解析的图形def在这里被转换回字符串!

是否有官方方法直接将序列化的GraphDef直接输入C / C ++ Tensorflow后端?我不需要重命名节点或合并多个图形,只需从磁盘加载已保存的图形。

谢谢!

0 个答案:

没有答案