对新手问题很抱歉。我尝试过查找示例,但是大多数情况下,它是伪代码/代码段,我无法运行,而且很难说出是什么,尤其是在所有不同的TF工作流程中。
我正在尝试将经过预训练的TensorFlow保存模型ResNet-50 v2(fp32)转换为量化的TensorFlow Lite文件,并且存在两个问题:
即使有错误消息,该模型也一目了然,因此,我最关心的是批量大小的修改。
我尝试转换的已保存模型:
我尝试用来生成新的保存模型的Checkpoint数据和.py模型:
我如何在bash中将其转换为.tflite:
tflite_convert --output_file resnet_imagenet_v2_uint8_20181001.tflite --saved_model_dir . --post_training_quantize
这为64x224x224x3建立了合理的模型。尽管有错误,这对于Android / iOS来说也可以运行(没有尝试过),但是我试图在自定义平台上使用它进行实验。
我尝试使用目标输入形状生成脚本的保存脚本:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.slim.nets import resnet_v2
def main():
# Directory containing resnet_v2_50.ckpt
ckpt_dir = "/media/resnet_v2_50_2017_04_14/"
with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope()):
input_tensor = tf.placeholder(tf.float32, shape=[1,224,224,3], name="input_tensor")
output_tensor = tf.placeholder(tf.float32, shape=[1,1000])
# Create model
# Generates errors for all Conv2D nodes like:
# 2018-10-19 16:41:41.393976: E tensorflow/core/framework/node_def_util.cc:110] Error in the node: {{node resnet_v2_50/conv1/Conv2D}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="VALID", strides=[1, 2, 2, 1], use_cudnn_on_gpu=true](resnet_v2_50/Pad, resnet_v2_50/conv1/weights/read)
net, end_points = resnet_v2.resnet_v2_50(input_tensor, 1000)
# Load checkpoint data
sv = tf.train.Supervisor(logdir=ckpt_dir)
with sv.managed_session() as sess:
# Allows saving, but unexpected results:
# sess.graph._unsafe_unfinalize()
# Below call fails with:
# RuntimeError: Graph is finalized and cannot be modified.
tf.saved_model.simple_save(
sess,
"./export",
inputs={"input_tensor": input_tensor},
outputs={"resnet_v2_50/predictions/Softmax": output_tensor}
)
main()
我认为我做错了什么,但看起来与我在网上找到的示例很接近,所以我有点迷茫。使用sess.graph._unsafe_unfinalize()
可以运行simple_save(),在export /下创建一个.pb和variables /目录,但是当我在Netron(我发现的一个模型查看器)中查看该节点时,节点数比提供的Saved Model(批量大小64)我从tensorflow模型库下载。 3395和1930。
无论如何尝试转换此模型都会导致此错误:
$ tflite_convert --output_file export.tflite --saved_model_dir . --post_training_quantize
.
.
.
Traceback (most recent call last):
File "/home/tfuser/venv/bin/tflite_convert", line 11, in <module>
sys.exit(main())
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 412, in main
app.run(main=run_main, argv=sys.argv[:1])
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 408, in run_main
_convert_model(tflite_flags)
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 162, in _convert_model
output_data = converter.convert()
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py", line 453, in convert
**converter_kwargs)
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py", line 370, in toco_convert_impl
input_data.SerializeToString())
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py", line 149, in toco_convert_protos
"TOCO failed see console for info.\n%s\n%s\n" % (stdout, stderr))
RuntimeError: TOCO failed see console for info.
b'2018-10-19 17:00:58.673690: F tensorflow/contrib/lite/toco/tooling_util.cc:886] Check failed: GetOpWithInput(model, input_array.name()) Specified input array "input_tensor" is not consumed by any op in this graph. Is it a typo? To silence this message, pass this flag: allow_nonexistent_arrays\n'
None
所以看来我对input_tensor的使用是错误的,并且正在创建很多额外的节点?
我的设置:
我在某些页面上寻求帮助:
关于如何实现这一目标的任何想法?我们将不胜感激并欢迎采用更好的策略或对脚本的修复。
次要问题:
sess.run(tf.global_variables_initializer())
?如果我尝试在上述脚本中调用它,则表示该图已完成。