TensorFlow import_graph_def raises an error after quantization

Date: 2017-12-09 14:55:32

Tags: python tensorflow

I am trying to generate an eight-bit quantized graph for a custom LSTM model using TransformGraph. If I use only quantize_weights, importing the graph works fine. After also applying quantize_nodes, the import fails with the error shown below.


ValueError: Specified colocation to an op that does not exist during import: lstm1/lstm1/BasicLSTMCellZeroState/zer in lstm1/lstm1/cond/Switch_2

The code snippet I use for quantization is listed below:

from tensorflow.tools.graph_transforms import TransformGraph
import tensorflow as tf

input_names = ["inp/X"]
output_names = ["out/Softmax"]
#transforms = ["quantize_weights", "quantize_nodes"]
#transforms = ["quantize_weights"]
transforms = ["add_default_attributes",
"strip_unused_nodes",
"remove_nodes(op=Identity, op=CheckNumerics)",
#"fold_constants(ignore_errors=true)",
"fold_batch_norms",
"fold_old_batch_norms",
"quantize_weights",
"quantize_nodes",
"sort_by_execution_order"]
#output_graph_path="/tmp/fixed.pb"
output_graph_path="/tmp/output_graph.pb"
with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with tf.Session() as sess:
            with open(output_graph_path, "rb") as f:

                output_graph_def.ParseFromString(f.read())
                _ = tf.import_graph_def(output_graph_def, name="")

                transformed_graph_def = TransformGraph(output_graph_def, input_names,
                                       output_names, transforms)

                tf.train.write_graph(transformed_graph_def, '/tmp', 'quantized.pb', as_text=False)
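
The error itself is raised when the transformed graph is imported again. A minimal sketch of that step, assuming it loads the quantized.pb written by the snippet above:

import tensorflow as tf

# Load the transformed (quantized) graph written by the transform step
quantized_graph_def = tf.GraphDef()
with open("/tmp/quantized.pb", "rb") as f:
    quantized_graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    # Works when only quantize_weights is applied; raises the colocation
    # ValueError once quantize_nodes is included
    tf.import_graph_def(quantized_graph_def, name="")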

I have also tried quantize_graph.py, which always ends in the KeyError described in https://github.com/tensorflow/tensorflow/issues/8025, and I believe that script is no longer maintained. Can anyone point out how to debug this problem?
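
To narrow this down, one check I can think of is to scan the transformed GraphDef for colocation constraints ("_class" attributes, entries of the form loc:@<node>) that refer to nodes the transforms have removed. A minimal sketch, assuming the quantized.pb path from the snippet above:

import tensorflow as tf

graph_def = tf.GraphDef()
with open("/tmp/quantized.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Names of all nodes that survived the transforms
node_names = {node.name for node in graph_def.node}

# Colocation constraints are stored in the "_class" attribute as
# entries of the form b"loc:@<target_node>"
for node in graph_def.node:
    if "_class" in node.attr:
        for loc in node.attr["_class"].list.s:
            target = loc.decode("utf-8").replace("loc:@", "")
            if target not in node_names:
                print("%s is colocated with missing node %s" % (node.name, target))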

0 Answers:

There are no answers yet.