我已经在tensorflow上训练了一个pix2pix模型,模型已经以检查点的形式保存,并带有以下文件:
model-15000.meta
,model-15000.index
,model-15000.data-00000-of-00001
,graph.pbtxt
,checkpoint
。
现在,我想将其转换为protobuf文件(.pb)以进行部署。我遇到了freeze_graph.py脚本,但是我遇到了其中一个参数的问题,它是output_node_names
。
我尝试了几个图层名称,但是我收到以下错误:
AssertionError:生成器/ decoder_2 / batchnorm / scale / gradient不在图中
不确定如何找到output_node_names
答案 0 :(得分:1)
尝试以下代码将meta转换为pb文件:
import tensorflow as tf
#Step 1
#import the model metagraph
saver = tf.train.import_meta_graph('./model.meta', clear_devices=True)
#make that as the default graph
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
#now restore the variables
saver.restore(sess, "./model")
#Step 2
# Find the output name
graph = tf.get_default_graph()
for op in graph.get_operations():
print (op.name)
#Step 3
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
output_node_names="predictions_mod/Sigmoid"
output_graph_def = graph_util.convert_variables_to_constants(
sess, # The session
input_graph_def, # input_graph_def is useful for retrieving the nodes
output_node_names.split(",") )
#Step 4
#output folder
output_fld ='./'
#output pb file name
output_model_file = 'model.pb'
from tensorflow.python.framework import graph_io
#write the graph
graph_io.write_graph(output_graph_def, output_fld, output_model_file, as_text=False)
希望这有效!!!
答案 1 :(得分:0)
尝试冻结模型时遇到同样的问题。
AssertionError: pose:0 is not in graph
我正在使用此脚本打印所有张量名称,但我仍然收到错误。
import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp
meta_path = './data/trained_variables.ckpt.meta' # Your .meta file
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,"/Users/me/Desktop/data/trained_variables.ckpt")
## Print tensors
chkp.print_tensors_in_checkpoint_file(file_name="/Users/me/Desktop/data/trained_variables.ckpt",
tensor_name='',
all_tensors=False,
all_tensor_names=True)
试一试,看看你是否能得到正确的名字。让我知道,我面临同样的问题。