我正在尝试使用TransformGraph为自定义LSTM模型生成八位量化图。如果我只使用quantze_weights,图表导入工作正常。应用quantize_nodes后,导入失败并显示错误,如下所示
ValueError:导入期间不存在的op的指定共置:lstm1 / lstm1 / cond / Switch_2中的lstm1 / lstm1 / BasicLSTMCellZeroState / zer
我用于量化的代码片段列在下面
from tensorflow.tools.graph_transforms import TransformGraph
import tensorflow as tf
input_names = ["inp/X"]
output_names = ["out/Softmax"]
#transforms = ["quantize_weights", "quantize_nodes"]
#transforms = ["quantize_weights"]
transforms = ["add_default_attributes",
"strip_unused_nodes",
"remove_nodes(op=Identity, op=CheckNumerics)",
#"fold_constants(ignore_errors=true)",
"fold_batch_norms",
"fold_old_batch_norms",
"quantize_weights",
"quantize_nodes",
"sort_by_execution_order"]
#output_graph_path="/tmp/fixed.pb"
output_graph_path="/tmp/output_graph.pb"
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with tf.Session() as sess:
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
transformed_graph_def = TransformGraph(output_graph_def, input_names,
output_names, transforms)
tf.train.write_graph(transformed_graph_def, '/tmp', 'quantized.pb', as_text=False)
我也尝试过使用quantize_graph.py,它总是导致https://github.com/tensorflow/tensorflow/issues/8025中的keyerror。我相信这段代码不再维护。任何人都可以指出如何调试此问题。