How do I quantize Facenet's Inception-ResNet-v1 model in Tensorflow?

Asked: 2018-07-09 12:38:20

Tags: python tensorflow quantization

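What I want to do

I am trying to create a quantized version of the Inception-ResNet-v1 model used in facenet (a Tensorflow implementation), following the graph_transform guide. The goal is a model with not only quantized weights but also quantized nodes.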

What I have tried

Using a model pretrained on the CASIA webface dataset, I tried to finetune the model with fake quantization nodes by adding the following line

tf.contrib.quantize.create_training_graph(quant_delay=0)
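to the facenet training script train_softmax.py, right after the total loss is computed (line 178), and the following line just before the checkpoint is saved (line 462):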

tf.contrib.quantize.create_eval_graph()

Then I fine-tune the pretrained model for 1000 iterations with a learning rate of 0.0005:

python3 src/train_softmax.py \
--logs_base_dir ~/logs/facenet/ \
--models_base_dir ${model_path} \
--data_dir ${casia_path} \
--image_size 160 \
--model_def models.inception_resnet_v1 \
--lfw_dir ${lfw_path} \
--optimizer ADAM \
--learning_rate -1 \
--max_nrof_epochs 150 \
--keep_probability 0.8 \
--random_crop \
--random_flip \
--use_fixed_image_standardization \
--learning_rate_schedule_file data/learning_rate_schedule_classifier_casia.txt \
--weight_decay 5e-4 \
--embedding_size 128 \
--lfw_distance_metric 1 \
--lfw_use_flipped_images \
--lfw_subtract_mean \
--validation_set_split_ratio 0.05 \
--validate_every_n_epochs 5 \
--prelogits_norm_loss_factor 5e-4 \
--center_loss_factor 2e-4 \
--gpu_memory_fraction 0.7 \
--pretrained_model ${model_path}/20180614-060325/model-20180614-060325.ckpt-90

So far, so good. Next, I freeze the resulting graph using facenet's freeze_graph:

python3 src/freeze_graph.py ${model_path}/20180709-100209/ ${model_path}/20180709-100209/model-20180709-100209-frozen.pb

Finally, I try to create a fully quantized model using transform_graph:

bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=${model_path}/20180709-100209/model-20180709-100209-frozen.pb \
--out_graph=${model_path}/20180709-100209/model-20180709-100209-quantized.pb \
--inputs='input,phase_train' \
--outputs='embeddings' \
--transforms='
  add_default_attributes
  strip_unused_nodes
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights
  quantize_nodes
  strip_unused_nodes
  sort_by_execution_order'
INFO: Analysed target //tensorflow/tools/graph_transforms:transform_graph (0 packages loaded).
INFO: Found 1 target...
Target //tensorflow/tools/graph_transforms:transform_graph up-to-date:
  bazel-bin/tensorflow/tools/graph_transforms/transform_graph
INFO: Elapsed time: 0.369s, Critical Path: 0.00s
INFO: 0 processes.
INFO: Build completed successfully, 1 total action
2018-07-09 10:28:57.970978: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     add_default_attributes
2018-07-09 10:28:58.068712: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     strip_unused_nodes
2018-07-09 10:28:58.204184: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     remove_nodes
2018-07-09 10:29:22.175832: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     fold_constants
2018-07-09 10:29:22.241960: E     tensorflow/tools/graph_transforms/transform_graph.cc:333] fold_constants: Ignoring error Input 0 of node InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/AssignMinLast was passed float from InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/min:0 incompatible with expected float_ref.
2018-07-09 10:29:22.294469: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     fold_batch_norms
2018-07-09 10:29:22.421606: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     fold_old_batch_norms
2018-07-09 10:29:22.772485: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     quantize_weights
2018-07-09 10:29:23.224347: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     quantize_nodes
2018-07-09 10:29:25.297763: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     strip_unused_nodes
2018-07-09 10:29:25.403029: I     tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying     sort_by_execution_order
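
The node named in the fold_constants message in the log above ends in AssignMinLast and expects a float_ref input, which suggests (this is an assumption, not something confirmed here) that it is a variable-update step left over from the training rewrite rather than something inference needs. A minimal sketch for spotting such nodes, working on a plain list of node names; the helper name `find_quant_update_nodes` is made up for illustration:

```python
def find_quant_update_nodes(node_names):
    """Return names that sit under a *_quant scope and contain an Assign* step.

    Hypothetical helper: it only inspects name components and makes no claim
    about what the ops actually do.
    """
    hits = []
    for name in node_names:
        parts = name.split("/")
        in_quant_scope = any(p.endswith("_quant") for p in parts)
        has_assign = any(p.startswith("Assign") for p in parts)
        if in_quant_scope and has_assign:
            hits.append(name)
    return hits


# Names taken from the log output in this post, plus one ordinary op.
example_names = [
    "InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/AssignMinLast",
    "InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/min",
    "InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/Conv2D",
]
update_nodes = find_quant_update_nodes(example_names)
print(update_nodes)
```

For a real frozen graph, the list of names could come from `[n.name for n in graph_def.node]` after parsing the .pb file.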

The fold_constants transform reports an error. Trying to run the resulting model produces the following error:

Traceback (most recent call last):
  File "benchmark_gpu.py", line 116, in <module>
    recognizer = FaceRecognizer(config)
  File "../facerecognizer.py", line 51, in __init__
    self.load_model()
  File "../facerecognizer.py", line 55, in load_model
    facenet.load_model(self.model)
  File "../3rd-party/facenet/src/facenet.py", line 373, in load_model
    tf.import_graph_def(graph_def, input_map=input_map, name='')
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/importer.py", line 602, in import_graph_def
    op_to_bind_to, node.name))
ValueError: Specified colocation to an op that does not exist during import: InceptionResnetV1/Repeat_1/block17_6/Conv2d_1x1/act_quant/min in InceptionResnetV1/Repeat_1/block17_6/Conv2d_1x1/act_quant/AssignMinEma/InceptionResnetV1/Repeat_1/block17_6/Conv2d_1x1/act_quant/min/AssignAdd/value
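
The ValueError above says a node is colocated with an op that no longer exists in the graph. As far as I understand, colocation constraints are stored in each node's `_class` attribute as entries of the form `loc:@<node name>`. A rough sketch for listing every dangling reference at once, written against a plain dict so it runs without Tensorflow; the function name and the toy node names are illustrative only:

```python
COLOCATION_PREFIX = "loc:@"


def find_dangling_colocations(colocations):
    """Map each node to the colocation targets that are missing from the graph.

    `colocations` maps node name -> list of '_class' entries such as
    'loc:@other_node'. Assumption: with a real GraphDef these entries would be
    read from node.attr['_class'].list.s and decoded from bytes.
    """
    known = set(colocations)
    dangling = {}
    for name, entries in colocations.items():
        missing = []
        for entry in entries:
            if not entry.startswith(COLOCATION_PREFIX):
                continue
            target = entry[len(COLOCATION_PREFIX):]
            if target not in known:
                missing.append(target)
        if missing:
            dangling[name] = missing
    return dangling


# Toy graph: 'act_quant/min' was stripped, but an update node still points at it.
toy_graph = {
    "act_quant/AssignMinEma/value": ["loc:@act_quant/min"],
    "act_quant/max": [],
    "act_quant/AssignMaxEma/value": ["loc:@act_quant/max"],
}
dangling = find_dangling_colocations(toy_graph)
print(dangling)
```

If many nodes show up here, that would hint that the training-time update ops are still in the frozen graph while the variables they refer to have been removed.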

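Attempts to fix the problem

The error produced by the fold_constants transform suggests that an op receives a constant where it expects a variable, so I tried adding all quantization nodes to the blacklist when converting variables to constants in freeze_graph.freeze_graph_def: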

# Get the list of important nodes
whitelist_names = []
blacklist_names = [] # <-- NEW
for node in input_graph_def.node:
    if (node.name.startswith('InceptionResnet') or node.name.startswith('embeddings') or 
        node.name.startswith('image_batch') or node.name.startswith('label_batch') or
        node.name.startswith('phase_train') or node.name.startswith('Logits')):
        whitelist_names.append(node.name)
    elif "quant" in node.name: # <-- NEW
        blacklist_names.append(node.name) # <-- NEW

# Replace all the variables in the graph with constants of the same values
output_graph_def = graph_util.convert_variables_to_constants(
    sess, input_graph_def, output_node_names.split(","),
    variable_names_whitelist=whitelist_names,
    variable_names_blacklist=blacklist_names) # <-- NEW
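
One detail worth noting about the loop above: the `elif "quant" in node.name` branch is reached only for names that fail every `startswith` test, yet the quantization nodes in the logs (for example InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/min) start with InceptionResnet, so they would land in the whitelist instead. A minimal sketch that checks for "quant" first; the function name is made up, and whether blacklisting these nodes actually resolves the error is untested:

```python
# Same name prefixes as in the snippet above.
PREFIXES = ("InceptionResnet", "embeddings", "image_batch",
            "label_batch", "phase_train", "Logits")


def split_variable_names(node_names):
    """Split node names into (whitelist, blacklist), testing 'quant' first."""
    whitelist, blacklist = [], []
    for name in node_names:
        if "quant" in name:
            # Quantization nodes are blacklisted even when they also match a prefix.
            blacklist.append(name)
        elif name.startswith(PREFIXES):
            whitelist.append(name)
    return whitelist, blacklist


names = [
    "InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights",
    "InceptionResnetV1/Repeat/block35_1/Conv2d_1x1/weights_quant/min",
    "embeddings",
    "unrelated/node",
]
whitelist_names, blacklist_names = split_variable_names(names)
print(whitelist_names)
print(blacklist_names)
```

The two lists can then be passed to `graph_util.convert_variables_to_constants` exactly as in the snippet above.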

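However, running graph_transform on the newly frozen model produces the same error as before, and removing the fold_constants transform from the graph_transform command leads to the same error as the one raised when trying to run the model.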
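
For trying the transform_graph command with individual transforms left out, the argument list can be assembled in Python. This is only a sketch: it uses the flags and transform names from the command shown earlier, the .pb file names are placeholders, and joining the transforms with spaces instead of newlines is assumed to be parsed the same way by transform_graph.

```python
TRANSFORMS = [
    "add_default_attributes",
    "strip_unused_nodes",
    "remove_nodes(op=Identity, op=CheckNumerics)",
    "fold_constants(ignore_errors=true)",
    "fold_batch_norms",
    "fold_old_batch_norms",
    "quantize_weights",
    "quantize_nodes",
    "strip_unused_nodes",
    "sort_by_execution_order",
]


def build_transform_command(in_graph, out_graph, inputs, outputs,
                            transforms=TRANSFORMS, skip=()):
    """Return an argv list for transform_graph, omitting transforms named in `skip`."""
    # Compare on the transform name only, ignoring any '(...)' arguments.
    kept = [t for t in transforms if t.split("(")[0] not in skip]
    return [
        "bazel-bin/tensorflow/tools/graph_transforms/transform_graph",
        "--in_graph=" + in_graph,
        "--out_graph=" + out_graph,
        "--inputs=" + ",".join(inputs),
        "--outputs=" + ",".join(outputs),
        "--transforms=" + " ".join(kept),
    ]


# 'frozen.pb' and 'quantized.pb' are placeholder paths.
command = build_transform_command(
    "frozen.pb", "quantized.pb",
    inputs=["input", "phase_train"], outputs=["embeddings"],
    skip=("fold_constants",),
)
print(command)
```

The resulting list can be handed to `subprocess.run(command, check=True)` from the Tensorflow source directory.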


Did I put the create_*_graph() functions in the wrong place? Have I misunderstood something else?

0 answers:

No answers