I downloaded the sample model mentioned in https://www.tensorflow.org/guide/extend/model_files from https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz and tried to convert it to an optimized graph:
```python
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

input_graph_path = './inception_v3_2016_08_28_frozen.pb'
output_optimized_graph_name = './model_optimized.pb'
input_node_names = ['input']
output_node_names = ['InceptionV3/Predictions/Reshape_1']

# Load the frozen GraphDef.
input_graph_def = tf.GraphDef()
with tf.gfile.Open(input_graph_path, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

# Run the inference-oriented rewrites (strip unused nodes, fold batch norms, ...).
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    input_node_names,
    output_node_names,
    tf.float32.as_datatype_enum)

# Write in binary mode; the `with` block closes (and flushes) the file.
with tf.gfile.GFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
```
Then I inspected both the original and the optimized graph with `bazel-bin/tensorflow/tools/graph_transforms/summarize_graph`, and they look roughly the same:
Original:
```
Found 1 possible inputs: (name=input, type=float(1), shape=[1,299,299,3])
No variables spotted.
Found 1 possible outputs: (name=InceptionV3/Predictions/Reshape_1, op=Reshape)
Found 23853946 (23.85M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 489 Const, 379 Identity, 188 Mul, 188 Add, 95 Conv2D, 94 Sub, 94 Rsqrt, 94 Relu, 15 ConcatV2, 10 AvgPool, 4 MaxPool, 2 Reshape, 1 BiasAdd, 1 Softmax, 1 Squeeze, 1 Placeholder
```
Optimized:
```
Found 1 possible inputs: (name=input, type=float(1), shape=None)
No variables spotted.
Found 1 possible outputs: (name=InceptionV3/Predictions/Reshape_1, op=Reshape)
Found 23853946 (23.85M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 489 Const, 188 Mul, 188 Add, 95 Conv2D, 94 Sub, 94 Relu, 94 Rsqrt, 15 ConcatV2, 10 AvgPool, 4 MaxPool, 2 Reshape, 1 BiasAdd, 1 Softmax, 1 Squeeze, 1 Placeholder
```
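`summarize_graph` is a C++ tool built with Bazel. The two things compared here (the input Placeholder's shape and the op-type counts) can also be read roughly from Python by walking the GraphDef protobuf directly. This is a minimal sketch, not `summarize_graph`'s actual logic; it assumes TensorFlow's generated protobuf modules (`tensorflow.core.framework`) are importable, and `describe_graph` / `load_graph_def` are hypothetical helper names.

```python
from collections import Counter

from tensorflow.core.framework import graph_pb2


def load_graph_def(path):
    """Parse a binary (frozen) GraphDef from a local file."""
    graph_def = graph_pb2.GraphDef()
    with open(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def describe_graph(graph_def):
    """Return (placeholder shapes, op-type counts) for a GraphDef.

    A Placeholder with no usable "shape" attr (missing, or unknown rank) is
    reported as None; a dimension of -1 means that dimension is unknown.
    """
    placeholders = {}
    for node in graph_def.node:
        if node.op != "Placeholder":
            continue
        attr = node.attr["shape"] if "shape" in node.attr else None
        if attr is not None and attr.HasField("shape") and not attr.shape.unknown_rank:
            placeholders[node.name] = [d.size for d in attr.shape.dim]
        else:
            placeholders[node.name] = None
    op_counts = Counter(node.op for node in graph_def.node)
    return placeholders, op_counts
```

For example, `describe_graph(load_graph_def('./model_optimized.pb'))` should show `None` for `input` if the `shape` attr is indeed what went missing.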
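Why is the input shape dropped after the conversion? Is it possible to preserve the shape, or to edit the graph after the conversion to add it back? (For the batch-norm question below, the relevant comment about BN folding is at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py#L28.)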
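One possible workaround is to edit the optimized GraphDef before serializing it and write a static shape back into the input Placeholder's `shape` attr, using the shape the original graph reported (`[1,299,299,3]`). A minimal sketch, assuming the missing `shape` attr is the only thing your consumer of the graph needs; `set_placeholder_shape` is a hypothetical helper, and it does not re-run TensorFlow's shape inference or check that the shape is valid for the model.

```python
from tensorflow.core.framework import graph_pb2, tensor_shape_pb2


def set_placeholder_shape(graph_def: graph_pb2.GraphDef, name: str, dims) -> bool:
    """Write a static shape into the "shape" attr of Placeholder `name`.

    Use -1 for a dimension that should stay unknown. Returns True if the
    node was found and edited, False otherwise.
    """
    for node in graph_def.node:
        if node.name == name and node.op == "Placeholder":
            shape = tensor_shape_pb2.TensorShapeProto(
                dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=d) for d in dims])
            # Overwrites any existing value of the attr.
            node.attr["shape"].shape.CopyFrom(shape)
            return True
    return False
```

In the conversion script, this would be called as `set_placeholder_shape(output_graph_def, 'input', [1, 299, 299, 3])` right before `SerializeToString()`; passing `[-1, 299, 299, 3]` instead would keep the batch size variable.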
Why is batch norm not folded during the conversion? Looking at the graph visualization in Netron, I see ops that I suspect are BN.
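The op counts in both summaries (94 Rsqrt, 94 Sub, 188 Mul, 188 Add, next to 95 Conv2D and a single BiasAdd) appear consistent with 94 batch-norm layers exported as plain elementwise ops: `inv = rsqrt(var + eps)` followed by `y * inv + (beta - mean * inv)`, which is the expansion `tf.nn.batch_normalization` uses when no scale is given. That reading is an inference from the counts, not something `summarize_graph` reports. Since `inv` and the shift are per-output-channel constants after a convolution, they can in principle be folded into the convolution weights plus a bias. The NumPy sketch below only checks that arithmetic, using a 1x1 convolution written as a matrix multiply (a simplification for brevity); it is not how `optimize_for_inference_lib` implements folding.

```python
import numpy as np


def batch_norm_decomposed(y, mean, var, beta, eps):
    """Unfused BN as elementwise ops: Add, Rsqrt, Mul, Mul, Sub, Add."""
    inv = 1.0 / np.sqrt(var + eps)
    return y * inv + (beta - mean * inv)


def fold_batch_norm(w, mean, var, beta, eps):
    """Fold per-output-channel BN into weights w of shape (C_in, C_out)."""
    inv = 1.0 / np.sqrt(var + eps)
    return w * inv, beta - mean * inv  # scaled weights, new bias


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))   # 8 pixels, 4 input channels (1x1 conv as matmul)
w = rng.standard_normal((4, 3))   # 3 output channels
mean = rng.standard_normal(3)
beta = rng.standard_normal(3)
var = rng.uniform(0.5, 2.0, 3)
eps = 1e-3                        # arbitrary epsilon for the demo

reference = batch_norm_decomposed(x @ w, mean, var, beta, eps)
w_folded, b_folded = fold_batch_norm(w, mean, var, beta, eps)
folded = x @ w_folded + b_folded
print(np.allclose(reference, folded))  # prints True
```

Folding works because `(x @ w) * inv` equals `x @ (w * inv)` when `inv` scales each output column independently.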
Update:
Regarding BN, it seems that only BatchNormWithGlobalNormalization and FusedBatchNorm are supported.
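To check which batch-norm form a given graph contains, one can count nodes of the two op types named in the update and, separately, Conv2D outputs that feed directly into a Mul, which is presumably how the elementwise BN starts in this graph. `bn_report` is a hypothetical helper written against the GraphDef protobuf; the Conv2D-into-Mul check is a heuristic assumption and will also count unrelated multiplies.

```python
from tensorflow.core.framework import graph_pb2

# Op types the update reports optimize_for_inference can fold.
FOLDABLE_BN_OPS = {"BatchNormWithGlobalNormalization", "FusedBatchNorm"}


def bn_report(graph_def: graph_pb2.GraphDef) -> dict:
    """Count foldable BN nodes and Mul nodes that consume a Conv2D output."""
    conv_names = {n.name for n in graph_def.node if n.op == "Conv2D"}
    fused = sum(1 for n in graph_def.node if n.op in FOLDABLE_BN_OPS)
    unfused = 0
    for n in graph_def.node:
        if n.op != "Mul":
            continue
        # Inputs look like "name", "name:1" or "^name" (control edge);
        # reduce each one to the bare producer node name.
        sources = {i.lstrip("^").split(":")[0] for i in n.input}
        if sources & conv_names:
            unfused += 1
    return {"foldable_bn_ops": fused, "conv_followed_by_mul": unfused}
```

On a graph like the one summarized above, a result with `foldable_bn_ops` equal to 0 and a nonzero `conv_followed_by_mul` would match the observation that nothing was folded.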