Here I have a TensorFlow object detection model trained with the Object Detection API. Some Nvidia GPUs (the Volta series, the P100, etc.) support float16 for faster training and inference, so truncating the weights from float32 to float16 looks like an attractive option. Mixed-precision training is of course the best choice, with almost no accuracy loss, but post-training quantization takes almost no extra work, and the converted model can be used to evaluate inference speed. So I decided to test it first anyway.
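As a quick sanity check on how much precision plain truncation costs: float16 keeps a 10-bit mantissa, so the relative rounding error of a typical weight is bounded by roughly 2^-11 ≈ 5e-4. A tiny illustration (not part of the conversion code below):

import numpy as np

w32 = np.float32(0.123456789)                       # a typical float32 weight value
w16 = np.float16(w32)                               # rounded to half precision
print(w32, w16, abs(np.float32(w16) - w32) / w32)   # relative error on the order of 1e-4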
I converted the detection model (SSD-MobileNetV2 FPN) with the code below, but something weird happened: loading the converted detection model on my machine (GTX 1080 Ti) takes more than 5 minutes, and loading it on CPU is just as slow. I suspected the slow loading might come from poor hardware support for float16, so I then tried a classification model (ResNet-18), and everything worked fine: it loads quickly and the outputs are almost identical. Has anyone tried a float16 object detection model (from mixed-precision training or post-training conversion) and run into a similar problem? Any suggestion is appreciated!
Gist link: https://gist.github.com/CasiaFan/5eebd085fff4aa0267e0132046b80437
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np
# object detection api input and output nodes
input_name = "image_tensor"
output_names = ["detection_boxes", "detection_classes", "detection_scores", "num_detections"]
# The NMS threshold Const nodes must stay float32 in the Object Detection API (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v4.html)
keep_fp32_node_name = ["Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/iou_threshold",
                       "Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/score_threshold"]
def load_graph(model_path):
    """Load a frozen graph (binary .pb or text .pbtxt) and return a tf.Session with it imported."""
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        return sess
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    """
    Rewrite FusedBatchNorm as FusedBatchNormV2, because reserve_space_1 and reserve_space_2 in
    FusedBatchNorm require float32 for gradient calculation
    (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm)
    """
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    # the extra "U" attr of FusedBatchNormV2 (dtype of the reserve spaces) must stay float32
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")
def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    """Convert a frozen float32 graph to the target dtype and write it to save_path/name."""
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    for node in source_graph_def.node:
        # fused batch norm node
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # replicate node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        attrs = list(node.attr.keys())
        # keep batch norm params node
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # replace dtype in node attr with target dtype
        for attr in attrs:
            # keep special node in fp32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                node.attr[attr].type = dtype
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        elif isinstance(input_name, str):
            # TransformGraph expects lists of node names, not a bare string
            input_name = [input_name]
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")
save_path = "test"
name = "test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type, input_name=input_name, output_names=output_names)
# test loading
# ISSUE: loading detection model is extremely slow while loading classification model is normal
sess = load_graph(save_path+"/"+name)
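To make the issue easier to reproduce, this is roughly how the load and a single inference can be timed. A minimal sketch: the 640x640 uint8 input shape is an assumption based on a typical SSD-MobileNetV2 FPN config, so adjust it to your own model; the input/output names come from the Object Detection API as defined at the top of the script.

import time

start = time.time()
sess_fp16 = load_graph(save_path + "/" + name)
print("FP16 graph load time: %.1f s" % (time.time() - start))   # > 5 min for my detection model

# one dummy inference with the Object Detection API input/output names
dummy_image = np.random.randint(0, 255, size=(1, 640, 640, 3), dtype=np.uint8)  # assumed input size
fetches = [sess_fp16.graph.get_tensor_by_name(n + ":0") for n in output_names]
feed = {sess_fp16.graph.get_tensor_by_name(input_name + ":0"): dummy_image}
start = time.time()
boxes, classes, scores, num = sess_fp16.run(fetches, feed_dict=feed)
print("FP16 inference time: %.3f s, detections: %d" % (time.time() - start, int(num[0])))

Running the same snippet against the original FP32 graph gives a baseline for both load time and output values.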