我试图读取tflite模型并将所有图层的参数拉出。
我的步骤:
flatc -python tensorflow/tensorflow/lite/schema/schema.fbs
结果是tflite/
文件夹,其中包含图层描述文件(*.py
)和一些实用程序文件。
如果发生导入错误:将PYTHONPATH设置为指向tflite /所在的文件夹
from tflite.Model import Model
def read_tflite_model(file):
buf = open(file, "rb").read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
return model
模型部分:
def print_model_info(model):
version = model.Version()
print("Model version:", version)
description = model.Description().decode('utf-8')
print("Description:", description)
subgraph_len = model.SubgraphsLength()
print("Subgraph length:", subgraph_len)
节点部分:
def print_nodes_info(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
operators_len = subgraph.OperatorsLength()
print('Operators length:', operators_len)
from collections import deque
nodes = deque(subgraph.InputsAsNumpy())
STEP_N = 0
MAX_STEPS = operators_len
print("Nodes info:")
while len(nodes) != 0 and STEP_N <= MAX_STEPS:
print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
print("-" * 60)
node_id = nodes.pop()
print("Node id:", node_id)
tensor = subgraph.Tensors(node_id)
print("Node name:", tensor.Name().decode('utf-8'))
print("Node shape:", tensor.ShapeAsNumpy())
# which type is it? what does it mean?
type_of_tensor = tensor.Type()
print("Tensor type:", type_of_tensor)
quantization = tensor.Quantization()
min = quantization.MinAsNumpy()
max = quantization.MaxAsNumpy()
scale = quantization.ScaleAsNumpy()
zero_point = quantization.ZeroPointAsNumpy()
print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))
# I do not understand it again. what is j, that I set to 0 here?
operator = subgraph.Operators(0)
for i in operator.OutputsAsNumpy():
nodes.appendleft(i)
STEP_N += 1
print("-"*60)
请向我介绍使用此API的文档或一些示例。
我的问题是:
我无法获得有关此API的文档
在我看来,不可能遍历Tensor对象,因为它没有Inputs和Outputs方法。 + subgraph.Operators(j=0)
我不明白j在这里的含义。因此,我的循环经历了两个节点:一次输入和下一个输入。
肯定可以遍历Operator对象:
在这里我们遍历所有对象,但是我不知道如何映射Operator和Tensor。
def print_in_out_info_of_all_operators(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
for i in range(subgraph.OperatorsLength()):
operator = subgraph.Operators(i)
print('Outputs', operator.OutputsAsNumpy())
print('Inputs', operator.InputsAsNumpy())
我不了解如何将参数拉出Operator对象。 BuiltinOptions方法为我提供了Table对象,我不知道该映射到什么位置。
subgraph = model.Subgraphs(0)
这个0是什么意思?应该总是零吗?显然没有,但这是什么?子图的ID?如果是这样-我很高兴。如果否,请尝试解释。