量化部署模型后如何进行

时间:2019-06-11 16:26:51

标签: tensorflow deployment

摘要

我最近正在研究一个大小约为500mb的模型,事实证明这在计算上非常昂贵。因此,我决定通过修剪here来减小尺寸。但是事实证明,该模型的大小仅改变了几mb。

但是,当我使用代码here(来自同一博客)对其进行量化时。我能够将尺寸减小到原始尺寸的大约1/4。但是,当我尝试使用这种缩小模型时,出现错误

InvalidArgumentError (see above for traceback): No OpKernel was registered to support Op 'Dequantize' used by node model/h0/mlp/c_proj/w (defined at app.py:25) with these attrs: [T=DT_QUINT8, mode="MIN_FIRST"]

Registered devices: [CPU]
Registered kernels:
<no registered kernels>
         [[node model/h0/mlp/c_proj/w (defined at app.py:25) ]]

经过一些研究,我得出的结论是,我需要向该图添加张量运算并从头开始对其进行重新训练。是这样吗我真的不愿意,因为我花了很多时间训练模型,所以如果有人可以提供一些见解,请告诉我!我尝试研究,但是找不到类似问题的人(或者我想念他们)。 代码如下。

代码

优化代码为

def optimize_graph(model_dir, graph_filename, output_node, input_nodes):
 transforms = [
   "remove_nodes(op=Identity)",
   "merge_duplicate_nodes",
   "strip_unused_nodes",
   "quantize_nodes",
   "quantize_weights"
]
input_names = input_nodes
  output_names = [output_node]
  if graph_filename is None:
    graph_def = get_graph_def_from_saved_model(model_dir)
  else:
    graph_def = get_graph_def_from_file(os.path.join(model_dir, graph_filename))
  optimized_graph_def = TransformGraph(
      graph_def,
      input_names,
      output_names,
      transforms)
  tf.train.write_graph(optimized_graph_def,
                      logdir=model_dir,
                      as_text=False,
                      name='optimized_model.pb')
  print('Graph optimized!')

进行优化。为了进行部署,我关注了博客here。并基本上使用烧瓶之类的方法加载了optimized_model.pb文件,如

app = Flask(__name__)
def load_graph(trained_model):   
    with tf.gfile.GFile(trained_model, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name=""
            )
    return graph
app.graph = load_graph('optimized_model.pb')

并通过获取张量来调用图

context = app.graph.get_tensor_by_name("Placeholder:0")
    temperature = app.graph.get_tensor_by_name("sample_sequence/while/ToFloat/x:0")
    output = app.graph.get_tensor_by_name("sample_sequence/while/Exit_3:0")

并在会话中打电话给他们

feed_dict = {temperature: temperature_input, context: context_tokens}
result = sess.run(output, feed_dict=feed_dict)
sess.close()

但是在此期间,会发生上述错误。

0 个答案:

没有答案