如何将线性分类器导出为TFLITE格式?

时间:2018-08-01 10:32:09

标签: tensorflow tensorflow-lite

我正在尝试根据可导出DNN分类器的示例转换线性分类器:

print("\n====== classifier model_dir, latest_checkpoint ===========")
print(classifier.model_dir)
print(classifier.latest_checkpoint())
debug = False

with tf.Session() as sess:
    # First let's load meta graph and restore weights
    latest_checkpoint_path = classifier.latest_checkpoint()
    saver = tf.train.import_meta_graph(latest_checkpoint_path + '.meta')
    saver.restore(sess, latest_checkpoint_path)
# Get the input and output tensors needed for toco.
# These were determined based on the debugging info printed / saved below.
input_tensor = sess.graph.get_tensor_by_name("dnn/input_from_feature_columns/input_layer/concat:0")
input_tensor.set_shape([1, 10])
out_tensor = sess.graph.get_tensor_by_name("dnn/logits/BiasAdd:0")
out_tensor.set_shape([1, 5])

# Pass the output node name we are interested in.
# Based on the debugging info printed / saved below, pulled out the
# name of the node for the logits (before the softmax is applied).
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=["dnn/logits/BiasAdd"])

if debug is True:
    print("\nORIGINAL GRAPH DEF Ops ===========================================")
    ops = sess.graph.get_operations()
    for op in ops:
        if "BiasAdd" in op.name or "input_layer" in op.name:
            print([op.name, op.values()])
    # save original graphdef to text file
    with open("estimator_graph.pbtxt", "w") as fp:
        fp.write(str(sess.graph_def))

    print("\nFROZEN GRAPH DEF Nodes ===========================================")
    for node in frozen_graph_def.node:
        print(node.name)
    # save frozen graph def to text file
    with open("estimator_frozen_graph.pbtxt", "w") as fp:
        fp.write(str(frozen_graph_def))

tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, [input_tensor], [out_tensor])
open("estimator_model.tflite", "wb").write(tflite_model)

但是我不知道在本节中使用哪个张量:

input_tensor = sess.graph.get_tensor_by_name("dnn/input_from_feature_columns/input_layer/concat:0")
    input_tensor.set_shape([1, 10])
    out_tensor = sess.graph.get_tensor_by_name("dnn/logits/BiasAdd:0")
    out_tensor.set_shape([1, 3])

我尝试作为输入张量: 线性/线性模型/线性模型/加权和:0 形状:1.5 (因为找不到适合1,10的张量)

并作为输出张量,具有:线性/头部/预测/概率:0 1,5形状

但是当我尝试在android设备中使用它时,输出张量的形状不再是1,5而是1,10 而且我不知道如何解释这个结果,也许问题是我不知道选择哪个张量作为toco_convert函数的输入

0 个答案:

没有答案