如何在python中读取.tflite模型的各层参数

时间:2019-03-21 12:50:00

标签: tensorflow tensorflow-lite

我试图读取tflite模型并将所有图层的参数拉出。

我的步骤:

  1. 我通过运行生成了flatbuffers模型表示形式(请在之前构建flatc):

flatc -python tensorflow/tensorflow/lite/schema/schema.fbs

结果是tflite/文件夹,其中包含图层描述文件(*.py)和一些实用程序文件。

  1. 我成功加载了模型:

如果发生导入错误:将PYTHONPATH设置为指向tflite /所在的文件夹

from tflite.Model import Model
def read_tflite_model(file):
    buf = open(file, "rb").read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    return model
  1. 我将模型和节点参数部分拉出并堆叠在节点上进行迭代:

模型部分:

def print_model_info(model):
        version = model.Version()
        print("Model version:", version)
        description = model.Description().decode('utf-8')
        print("Description:", description)
        subgraph_len = model.SubgraphsLength()
        print("Subgraph length:", subgraph_len)

节点部分:

def print_nodes_info(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    operators_len = subgraph.OperatorsLength()
    print('Operators length:', operators_len)

    from collections import deque
    nodes = deque(subgraph.InputsAsNumpy())

    STEP_N = 0
    MAX_STEPS = operators_len
    print("Nodes info:")
    while len(nodes) != 0 and STEP_N <= MAX_STEPS:
        print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
        print("-" * 60)

        node_id = nodes.pop()
        print("Node id:", node_id)

        tensor = subgraph.Tensors(node_id)
        print("Node name:", tensor.Name().decode('utf-8'))
        print("Node shape:", tensor.ShapeAsNumpy())

        # which type is it? what does it mean?
        type_of_tensor = tensor.Type()
        print("Tensor type:", type_of_tensor)

        quantization = tensor.Quantization()
        min = quantization.MinAsNumpy()
        max = quantization.MaxAsNumpy()
        scale = quantization.ScaleAsNumpy()
        zero_point = quantization.ZeroPointAsNumpy()
        print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))

        # I do not understand it again. what is j, that I set to 0 here?
        operator = subgraph.Operators(0)
        for i in operator.OutputsAsNumpy():
            nodes.appendleft(i)

        STEP_N += 1

    print("-"*60)

请向我介绍使用此API的文档或一些示例。

我的问题是:

  1. 我无法获得有关此API的文档

  2. 在我看来,不可能遍历Tensor对象,因为它没有Inputs和Outputs方法。 + subgraph.Operators(j=0)我不明白j在这里的含义。因此,我的循环经历了两个节点:一次输入和下一个输入。

  3. 肯定可以遍历Operator对象:

在这里我们遍历所有对象,但是我不知道如何映射Operator和Tensor。

def print_in_out_info_of_all_operators(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    for i in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(i)
        print('Outputs', operator.OutputsAsNumpy())
        print('Inputs', operator.InputsAsNumpy())
  1. 我不了解如何将参数拉出Operator对象。 BuiltinOptions方法为我提供了Table对象,我不知道该映射到什么位置。

  2. subgraph = model.Subgraphs(0) 这个0是什么意思?应该总是零吗?显然没有,但这是什么?子图的ID?如果是这样-我很高兴。如果否,请尝试解释。

0 个答案:

没有答案