我试图在pbtxt文件中冻结一个包含batchnorm图层的检查点(tf 1.1.0)。 为此,请关注这些帖子和问题:
我使用这个功能:
freeze_and_prune_graph(model_path_and_name, output_file=None):
"""
freezes a model trained and saved by the trainer by :
- extracting the trainable variables between input_node and output_node
- turning them to constants
- changing the 1rst dim of input_node to None
-saving the resulting graph as a single .pb file
:param model_path_and_name: must finish by .ckpt, and the checkpoint must be composed of
3+ files : .ckpt.index, .ckpt.meta, and .ckpt.data-0000X-of-0000Y
:param model_path_and_name: path to the trained model
:param output_file: file to save to. If None, model_path_and_name.[-ckpt][+pb]
:return: None
"""
config_proto = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config_proto) as sess:
new_saver = tf.train.import_meta_graph(model_path_and_name + '.meta', clear_devices=True)
tf.get_default_session().run(tf.global_variables_initializer())
tf.get_default_session().run(tf.local_variables_initializer())
new_saver.restore(sess, model_path_and_name)
# get graph definition
gd = sess.graph.as_graph_def()
# fix batch norm nodes
for node in gd.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
for index in xrange(len(node.input)):
if 'moving_' in node.input[index]:
node.input[index] = node.input[index] + '/read'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr: del node.attr['use_locking']
elif node.op == 'AssignAdd':
node.op = 'Add'
if 'use_locking' in node.attr: del node.attr['use_locking']
# tf.get_collection() returns a list. In this example we only want the
input_node = sess.graph.get_tensor_by_name('input_node:0')
new_shape = [None] + input_node.get_shape().as_list()[1:]
trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
new_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, ["output_node"],
variable_names_whitelist=[t.name[:-2] for t in trainables] + ['output_node'])
for node in new_graph_def.node:
if node.name == 'input_node':
node.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=tf.TensorShape(new_shape).as_proto()))
break
with tf.gfile.GFile(output_file, "wb") as f:
f.write(new_graph_def.SerializeToString())
print("{0} / {1} ops in the final graph.".format(len(new_graph_def.node), len(sess.graph.as_graph_def().node)))
这很顺利,并使用以下输出创建pbtxt文件:
Converted 201 variables to const ops.
5287 / 41028 ops in the final graph.
然后我尝试使用此函数加载pbtxt
模型:
def load_frozen_graph(frozen_graph_file):
"""
loads a graph frozen via freeze_and_prune_graph and returns the graph, its input placeholder and output tensor
:param frozen_graph_file: .pb file to load
:return: tf.graph, tf.placeholder, tf.tensor
"""
# We load the protobuf file from the disk and parse it to retrieve the
# unserialized graph_def
with tf.gfile.GFile(frozen_graph_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Then, we can use again a convenient built-in function to import a graph_def into the
# current default Graph
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="prefix",
op_dict=None,
producer_op_list=None
)
input_images_placeholder = graph.get_tensor_by_name('prefix/input_node:0')
input_phase_placeholder = None
try:
input_phase_placeholder = graph.get_tensor_by_name('prefix/phase:0')
except KeyError:
pass
output = graph.get_tensor_by_name('prefix/output_node:0')
return graph, input_images_placeholder, input_phase_placeholder, output
使用以下代码段:
graph, input_images_placeholder, is_training_placeholder, output = load_frozen_graph(model_pbtxt)
sess = tf.Session(config=tf_config, graph=graph)
feed_dict = {input_images_placeholder: prepared_input}
if is_training_placeholder is not None:
feed_dict[is_training_placeholder] = False
ret = sess.run([output], feed_dict=feed_dict)
但是,这会导致以下错误:
FailedPreconditionError (see above for traceback):
Attempting to use uninitialized value prefix/conv0/BatchNorm/batch_normalization/moving_mean
[
[
Node: prefix/conv0/BatchNorm/batch_normalization/moving_mean/read = Identity[
T=DT_FLOAT,
_class=["loc:@prefix/conv0/BatchNorm/batch_normalization/moving_mean"],
_device="/job:localhost/replica:0/task:0/gpu:0"
](prefix/conv0/BatchNorm/batch_normalization/moving_mean)
]
]
[
[
Node: prefix/output_node/_381 = _Recv[
client_terminated=false,
recv_device="/job:localhost/replica:0/task:0/cpu:0",
send_device="/job:localhost/replica:0/task:0/gpu:0",
send_device_incarnation=1,
tensor_name="edge_2447_prefix/output_node",
tensor_type=DT_FLOAT,
_device="/job:localhost/replica:0/task:0/cpu:0"
]()
]
]
以下问题: TensorFlow: "Attempting to use uninitialized value" in variable initialization 我尝试初始化变量:
graph, input_images_placeholder, is_training_placeholder, output = load_frozen_graph(model_pbtxt)
sess = tf.Session(config=tf_config, graph=graph)
init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
sess.run(init)
feed_dict = {input_images_placeholder: prepared_input}
if is_training_placeholder is not None:
feed_dict[is_training_placeholder] = False
ret = sess.run([self.output], feed_dict=feed_dict)
然而,这会将错误更改为:
ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor.
(Operation name: "init" op: "NoOp" is not an element of this graph.)
似乎表明没有需要初始化的变量。
我错过了什么?如何冻结并重新加载batch_normalization图层的相关值?