如何使用tflite_convert量化由tf.estimator.Estimator训练的网络?

时间:2018-06-26 12:23:21

标签: python tensorflow tensorflow-lite

tflite_convert是一个Python脚本,用于调用TOCO(TensorFlow Lite优化转换器)将文件从Tensorflow的格式转换为与tflite兼容的文件。

我正在尝试从我用Estimator训练的网络开始生成量化的TFlite模型。训练代码非常简单,我按照Fixed Point Quantization guide的要求添加了必要的修改以对模型进行微调:

def input_fn(mode, num_classes, batch_size=1):
  #[...]
  return {'images': images}, labels

def model_fn(features, labels, num_classes, mode):
  images = features['images']
  with tf.contrib.slim.arg_scope(net_arg_scope()):
    logits, end_points = build_net(...)

  if FLAGS.with_quantization:
    tf.logging.info("Applying quantization to the graph.")
    if mode == tf.estimator.ModeKeys.EVAL:
      tf.contrib.quantize.create_eval_graph()

  tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well

  if FLAGS.with_quantization:
    tf.logging.info("Applying quantization to the graph.")
    if mode == tf.estimator.ModeKeys.TRAIN:
      tf.contrib.quantize.create_training_graph()

  # Configure the training op, etc [...]
  return tf.estimator.EstimatorSpec(...)

def main(unused_argv):
  regex = FINETUNE_LAYER_RE if not FLAGS.with_quantization else '^((?!_quant).)*$'
  ws_settings = tf.estimator.WarmStartSettings(FLAGS.pretrained_checkpoint, regex)

  # Create the Estimator
  estimator = tf.estimator.Estimator(
    model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
    model_dir=FLAGS.model_dir,
    #config=run_config,
    warm_start_from=ws_settings)

  # Set up input functions for training and evaluation
  train_input_fn = lambda : input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, FLAGS.batch_size)
  eval_input_fn = lambda : input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, FLAGS.batch_size)

  #[...]

  train_spec = tf.estimator.TrainSpec(...)
  eval_spec = tf.estimator.EvalSpec(...)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

我遇到的第一个问题是,添加量化操作后,不可能简单地使用最新的检查点继续训练。这是因为量化增加了在检查点中找不到的额外变量。我解决了编写热启动规范的问题,该规范按名称过滤了所有新变量,并将训练中的最新检查点用作热启动检查点。

现在,我想生成一个评估图以保存(带有相关变量),然后通过tflite_convert脚本将其提供给TOCO。 我尝试转换每次评估后导出的SavedModel之一,但这会引发以下错误:

  

数组conv0_bn / FusedBatchNorm,它是Relu运算符的输入   产生输出数组cell_stem_0 / Relu,缺少最小/最大数据,   这是量化所必需的。定位非量化对象   输出格式,或将输入图更改为包含最小/最大   信息,或传递--default_ranges_min =和--default_ranges_max =   如果您不关心结果的准确性。   中止(核心已弃用)

我不知道如何获得正确的SavedModel或一对GraphDef +检查点(尽管最好使用SavedModel) 有没有人试图量化估计模型?您如何生成量化评估图?

0 个答案:

没有答案