Tensorflow - optimize_for_inference_lib.optimize_for_inference loses the output node specified with tf.identity

Time: 2017-06-19 08:54:58

Tags: tensorflow

TensorFlow version: 1.1.0

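I used tf.identity to name the input and output nodes of my model "input" and "output".

After freezing, I can find both "input" and "output" in graph_def.node.

After optimizing, "input" is still there, but "output" is gone.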

  

Found in frozen
Traceback (most recent call last):
  File "/home/xiusir/.virtualenvs/ir/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 480, in import_graph_def
    ret.append(name_to_op[name])
KeyError: 'output'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "graph.py", line 43, in <module>
    output_tensor = tf.import_graph_def(graph_def, name='', return_elements=[MODEL_OUTPUT_TENSOR_NAME])
  File "/home/xiusir/.virtualenvs/ir/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 483, in import_graph_def
    'Requested return_element %r not found in graph_def.' % name)
ValueError: Requested return_element 'output' not found in graph_def.
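The traceback shows where the failure comes from: import_graph_def looks each name in return_elements up in a name-to-op mapping, the missing 'output' raises a KeyError, and that is re-raised as the ValueError. The sketch below is a hypothetical pure-Python model of just those two lines (lookup_return_elements is not a TensorFlow function, and the node lists are simplified stand-ins for real graphs), meant to show why the error only appears for the optimized graph.

```python
def lookup_return_elements(node_names, return_elements):
    """Hypothetical toy model of the return_elements lookup seen in the traceback.

    node_names: names of the nodes present in a GraphDef.
    return_elements: names the caller asks import_graph_def to return.
    """
    name_to_op = {name: name for name in node_names}  # stand-in for real op objects
    ret = []
    for name in return_elements:
        try:
            ret.append(name_to_op[name])
        except KeyError:
            # Raising inside the except block chains the two exceptions, which is
            # why the traceback prints "During handling of the above exception...".
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
    return ret


frozen_nodes = ["input", "output"]  # simplified: 'output' is still present
optimized_nodes = ["input"]         # simplified: 'output' has been removed

print(lookup_return_elements(frozen_nodes, ["output"]))  # ['output']
try:
    lookup_return_elements(optimized_nodes, ["output"])
except ValueError as err:
    print(err)  # Requested return_element 'output' not found in graph_def.
```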

train.py

def inference(images):
  images = tf.identity(images, 'input')
  ... ...
  softmax_linear = tf.identity(softmax_linear, 'output')
  return softmax_linear

package.py - generates the model files

# -*- coding: utf-8 -*-
# Preparing a TF model for usage in Android
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'carc34'
path='.'
input_graph_path = '%s/graph.pbtxt' % path
checkpoint_path = '%s/model.ckpt-100000' % path
input_saver_def_path = ""
input_binary = False
input_node_names = "input"
output_node_names = "output,softmax_linear/softmax_linear"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_'+MODEL_NAME+'.pb'
output_optimized_graph_name = 'optimized_'+MODEL_NAME+'.pb'
clear_devices = True

# Freeze the graph
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    input_graph_def.ParseFromString(f.read())

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        input_node_names.split(","),   # a list of the input node name(s)
        output_node_names.split(","),  # a list of the output node names
        tf.float32.as_datatype_enum)

# Save the optimized graph: binary mode for the serialized bytes, and a
# context manager so the file is flushed and closed
with tf.gfile.GFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
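To catch the problem when packaging rather than when importing, the optimized GraphDef can be checked for the expected output names before it is written. A minimal sketch, assuming the names are plain node names with no ":0" tensor suffix; missing_nodes is a hypothetical helper, not part of TensorFlow, and the SimpleNamespace objects only stand in for a real GraphDef so the example runs without TensorFlow.

```python
from types import SimpleNamespace


def missing_nodes(graph_def, node_names):
    """Return the entries of node_names that have no node of that name in graph_def.

    graph_def: anything with a .node sequence whose items have a .name attribute,
    such as a tf.GraphDef. node_names: plain node names (no ":0" tensor suffix).
    """
    present = {node.name for node in graph_def.node}
    return [name for name in node_names if name not in present]


# Hypothetical stand-in for an optimized GraphDef that lost its 'output' node.
fake_graph_def = SimpleNamespace(node=[
    SimpleNamespace(name="input"),
    SimpleNamespace(name="softmax_linear/softmax_linear"),
])

print(missing_nodes(fake_graph_def, ["output", "softmax_linear/softmax_linear"]))
# ['output']
```

In package.py this check could run right after optimize_for_inference returns, for example raising an error when missing_nodes(output_graph_def, output_node_names.split(",")) is non-empty, so a graph without its outputs is never saved.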

show_nodes.py - checks the frozen and the optimized graph for the output node (uncomment the loops to print every node)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

MODEL_OUTPUT_TENSOR_NAME = 'output'
with tf.Graph().as_default() as graph:
  model_filename = os.path.join('.', 'frozen_carc34.pb')
  with gfile.FastGFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    ##for n in graph_def.node:
    ##  print (n.name)
    output_tensor = tf.import_graph_def(graph_def, name='', return_elements=[MODEL_OUTPUT_TENSOR_NAME])
    print ("Found in frozen")

  model_filename = os.path.join('.', 'optimized_carc34.pb')
  with gfile.FastGFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    ##for n in graph_def.node:
    ##  print (n.name)
    output_tensor = tf.import_graph_def(graph_def, name='', return_elements=[MODEL_OUTPUT_TENSOR_NAME])
    print ("Found in optimized")

1 answer:

Answer (score: 1)

Try changing input_node_names and output_node_names to lists of strings. That solved a similar problem for me.
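Following that suggestion, a minimal sketch that keeps the node names as Python lists and derives a comma-separated string only where one is needed. It assumes, as package.py already does, that freeze_graph.freeze_graph takes its output node names as a single comma-separated string while optimize_for_inference takes lists; INPUT_NODE_NAMES, OUTPUT_NODE_NAMES and normalize_node_names are hypothetical names.

```python
# Hypothetical constants: keep the node names as lists of strings.
INPUT_NODE_NAMES = ["input"]
OUTPUT_NODE_NAMES = ["output", "softmax_linear/softmax_linear"]


def normalize_node_names(names):
    """Accept a comma-separated string or a list and return a clean list.

    Whitespace around each name is stripped and empty entries are dropped, so
    "output, softmax_linear/softmax_linear" or "output," still give exact names.
    """
    if isinstance(names, str):
        names = names.split(",")
    return [name.strip() for name in names if name.strip()]


# freeze_graph wants one comma-separated string ...
freeze_output_arg = ",".join(OUTPUT_NODE_NAMES)
# ... while optimize_for_inference wants lists.
optimize_input_arg = normalize_node_names(INPUT_NODE_NAMES)
optimize_output_arg = normalize_node_names(freeze_output_arg)

print(freeze_output_arg)    # output,softmax_linear/softmax_linear
print(optimize_output_arg)  # ['output', 'softmax_linear/softmax_linear']
```

The design choice here is a single source of truth for the names: a stray space such as " softmax_linear/softmax_linear" would otherwise not match any node.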

It is also possible that freezing the graph, or stripping its unused nodes, removes output nodes that are just Identity nodes.
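As I understand TensorFlow 1.x, optimize_for_inference first runs a strip-unused step that rewrites each name in input_node_names as a Placeholder, and then graph_util.remove_training_nodes, which bypasses Identity ops; that would explain why "input" survives while "output" disappears. Later TensorFlow releases added a protected_nodes argument to remove_training_nodes for this case. The sketch below is a toy pure-Python model of such an Identity-stripping pass under those assumptions, not TensorFlow's code: a graph is a dict of name -> (op type, input names), and the protected parameter is a hypothetical stand-in for protected_nodes.

```python
def strip_identity_nodes(graph, protected=()):
    """Toy model: drop Identity nodes and rewire their consumers to the Identity's input.

    graph maps node name -> (op_type, [input names]); names in `protected`
    are kept even when they are Identity ops.
    """
    protected = set(protected)
    # Map each removable Identity node to the node it forwards.
    bypass = {
        name: inputs[0]
        for name, (op, inputs) in graph.items()
        if op == "Identity" and name not in protected
    }

    def resolve(name):
        # Follow chains of removed Identity nodes back to a surviving node.
        while name in bypass:
            name = bypass[name]
        return name

    return {
        name: (op, [resolve(i) for i in inputs])
        for name, (op, inputs) in graph.items()
        if name not in bypass
    }


# Simplified graph from train.py, assuming 'input' was already turned into a
# Placeholder and that softmax_linear/softmax_linear is an Add op.
graph = {
    "input": ("Placeholder", []),
    "softmax_linear/softmax_linear": ("Add", ["input"]),
    "output": ("Identity", ["softmax_linear/softmax_linear"]),
}

print(sorted(strip_identity_nodes(graph)))
# ['input', 'softmax_linear/softmax_linear']
print(sorted(strip_identity_nodes(graph, protected=["output"])))
# ['input', 'output', 'softmax_linear/softmax_linear']
```

In this model softmax_linear/softmax_linear survives either way, which suggests requesting that name from import_graph_def as a workaround that does not depend on the Identity node.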