我有一个多输入网络，它使用一个 tf.bool 类型的 tf.placeholder 来控制批量归一化（batch normalization）在训练阶段与验证/测试阶段的执行方式。
我一直在尝试通过 tf-coreml 库将这个训练好的模型转换为 CoreML 模型，但没有成功，出现以下错误：
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value
我了解到，此错误表示某些节点没有值，导致转换器无法执行模型。我也知道此错误与控制流操作有关（批量归一化会创建 Switch 和 Merge 这类控制流操作）。TensorFlow 源代码中的相关测试如下：
def testSwitchDeadBranch(self):
    """Evaluating the unused output of a Switch must fail.

    The Switch predicate is the constant ``True``, and the test consumes
    output index 0 of the Switch (presumably the "false" port -- TF's Switch
    returns ``(output_false, output_true)``; confirm against the TF docs).
    Evaluating a tensor derived from that output is expected to raise
    ``InvalidArgumentError`` whose message contains
    ``Retval[0] does not have value``.
    """
    with self.cached_session():
        values = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
        predicate = ops.convert_to_tensor(True, name="ports")
        untaken_output = control_flow_ops.switch(values, predicate)[0]
        consumer = array_ops.identity(untaken_output)

        def _mentions_missing_retval(err):
            # Match on the message text that TF attaches to the error.
            return "Retval[0] does not have value" in str(err)

        with self.assertRaisesWithPredicateMatch(
                errors_impl.InvalidArgumentError, _mentions_missing_retval):
            self.evaluate(consumer)
请注意，我遇到的错误是 Retval[26]（也遇到过 Retval[24] 等），而不是 Retval[0]。我猜这个测试针对的是 Switch 的“死分支”，即推理时不会被使用的那个分支。源代码中对 Merge 的“死分支”也有类似的测试。
我是否遗漏了某些可能导致此错误的细节（当然，这并不是我在转换过程中遇到的第一个错误）？是推理的方式？批量归一化的实现方式？还是模型的保存方式？
我到目前为止所做的：
- 使用 TensorFlow 1.14.0；
- 了解到 tf.layers.batch_normalization 创建的 Switch 和 Merge 操作与 CoreML 不兼容；
- 尝试了 TensorFlow Lite；
- 参照了 Facenet 的转换过程（该模型同样使用 tf.bool 逻辑来区分训练、验证和测试），但没有成功；
- 尝试了 GraphTransforms 库。
注意：为了发布此问题，我删减了大部分代码，只保留了相关片段。
这是在卷积块内实现批量归一化的方式。
training = tf.placeholder(tf.bool, shape=(), name='training')


def conv_layer(input, kernelSize, nFilters, poolSize, stride, input_channels=1,
               name='conv', conv_stride=2, is_training=None):
    """Build a convolution -> max-pool -> batch-norm -> ReLU block.

    Args:
        input: Tensor to convolve. Assumed NHWC, since the filter's third
            dimension is ``input_channels`` -- TODO confirm.
        kernelSize: Height and width of the square convolution filter.
        nFilters: Number of convolution filters (output channels).
        poolSize: Height and width of the square max-pool window.
        stride: Spatial stride of the max-pool window.
        input_channels: Number of channels in ``input``.
        name: Name scope for the block's ops.
        conv_stride: Spatial stride of the convolution. Defaults to 2, the
            value that was previously hard-coded.
        is_training: Flag forwarded to batch normalization as ``training``.
            ``None`` (the default) uses the module-level ``training``
            placeholder. Passing a Python bool (e.g. ``False`` when building
            an inference-only graph for export) should let batch
            normalization pick its branch at graph-construction time, so no
            Switch/Merge ops are created -- verify for your TF version.

    Returns:
        The ReLU activation tensor.

    NOTE(review): tf.layers.batch_normalization puts its moving-mean and
    moving-variance updates in tf.GraphKeys.UPDATE_OPS. Confirm that train_op
    depends on them; otherwise the inference-mode statistics are never
    updated.
    """
    if is_training is None:
        is_training = training

    with tf.name_scope(name):
        shape = [kernelSize, kernelSize, input_channels, nFilters]
        weights = new_weights(shape=shape)
        biases = new_biases(length=nFilters)

        conv = tf.nn.conv2d(input, weights,
                            strides=[1, conv_stride, conv_stride, 1],
                            padding='SAME', name='convL')
        conv += biases

        # Pooling strides are [batch, height, width, channels]. TF does not
        # support striding over batch/channels, so only the spatial entries
        # vary. (The filter ``shape`` used previously is not a valid strides
        # vector.)
        pool = tf.nn.max_pool(conv, ksize=[1, poolSize, poolSize, 1],
                              strides=[1, stride, stride, 1], padding='SAME')

        bnorm = tf.layers.batch_normalization(
            pool, training=is_training, center=True, scale=True,
            fused=False, reuse=False)
        act = tf.nn.relu(bnorm)
        return act
下面是训练和保存模型的代码。
# Train for MAX_EPOCHS epochs, running a validation pass after each one,
# then write a checkpoint plus metagraph under saveDir.
saver = tf.train.Saver()
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(init_train_op)

    for _ in range(MAX_EPOCHS):
        # Training pass: the `training` placeholder is fed True.
        for _ in range(10):
            l, _, se = sess.run(
                [loss, train_op, mean_squared_error],
                feed_dict={training: True})

        # Validation pass: re-initialise onto the validation data (via
        # init_val_op), then feed `training` as False.
        print('\nRunning validation operation...')
        sess.run(init_val_op)
        for _ in range(10):
            val_out, val_l, val_se = sess.run(
                [out, val_loss, val_mean_squared_error],
                feed_dict={training: False})
        sess.run(init_train_op)  # switch back to training set

    # Save model.
    # The path literal contains no '{}' placeholder, so the former
    # `.format(modelIndex)` had no effect and has been removed. The path
    # itself is unchanged. Add a placeholder if per-model checkpoint names
    # were intended.
    print('Saving Model...\n')
    saver.save(sess, join(saveDir, './model_saver_validation'),
               write_meta_graph=True)
下面是用于加载,更新输入,执行推理和冻结模型的代码。
# Dummy data for inference: one batch item per input placeholder.
b = np.zeros((1, 80, 160, 1), np.float32)
ill = np.ones((1, 3), np.float32)
is_train = False


def freeze():
    """Re-import the trained graph with placeholder inputs, restore it, and freeze it.

    The dataset-iterator outputs and the training flag are re-wired to fresh
    placeholders ('bIn', 'illumIn', 'training'). The checkpoint is restored,
    and one inference run is done as a smoke test. The frozen GraphDef is then
    written to ``frozen_graph``.

    Relies on names defined outside this excerpt: ``meta_graph``,
    ``checkpoint_file``, ``freeze_graph_def``, ``output_node_names`` and
    ``frozen_graph``.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            bIn = tf.placeholder(dtype=tf.float32, shape=[1, 80, 160, 1],
                                 name='bIn')
            illumIn = tf.placeholder(dtype=tf.float32, shape=[1, 3],
                                     name='illumIn')
            training = tf.placeholder(tf.bool, shape=(), name='training')

            # Import the training metagraph. Redirect the iterator outputs and
            # the training flag to the placeholders created above.
            saver = tf.train.import_meta_graph(
                meta_graph,  # .meta file from saver.save()
                input_map={
                    'IteratorGetNext:0': bIn,
                    'IteratorGetNext:3': illumIn,
                    'training:0': training,
                },
                clear_devices=True)
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, checkpoint_file)

            # Smoke-test inference. Feed exactly the placeholders created in
            # this graph (the previous 'poseIn:0' / `po` pair is not defined
            # here, and 'illumIn' was left unfed).
            pred = sess.graph.get_tensor_by_name('Out:0')
            sess.run(pred, feed_dict={bIn: b, illumIn: ill, training: is_train})

            # NOTE(review): nothing here rewrites the batch-norm nodes. The
            # imported graph still holds the Switch/Merge ops that batch
            # normalization built around the tensor-valued `training` flag.
            # Mapping 'training:0' to a placeholder keeps them in the frozen
            # graph. A likely remedy is to rebuild the model with
            # training=False (a Python bool) and restore this checkpoint
            # into it before freezing -- verify.
            input_graph_def = sess.graph.as_graph_def()

            # Freeze the graph def.
            output_graph_def = freeze_graph_def(
                sess, input_graph_def, output_node_names)

            # Serialize and dump the output graph to the filesystem.
            with tf.gfile.GFile(frozen_graph, 'wb') as f:
                f.write(output_graph_def.SerializeToString())


freeze()
下面是要转换为CoreML的代码。
# Convert the frozen TF graph to a Core ML model. Every graph input is listed
# with its shape; the scalar boolean 'training' flag has an empty shape.
tfcoreml.convert(
    tf_model_path=frozen_graph,
    mlmodel_path='./coreml_model.mlmodel',
    input_name_shape_dict={
        'bIn:0': [1, 80, 160, 1],
        'illumIn:0': [1, 3],
        'training:0': [],
    },
    output_feature_names=['Out:0'],
)
以下是 tf-coreml 抛出的错误：
Loading the TF graph...
Graph Loaded.
Collecting all the 'Const' ops from the graph, by running it....
Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
return fn(*args)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "tf2opencv.py", line 392, in <module>
'illumIn:0': [1, 3], 'poseIn:0': [1, 16], 'training:0': []})
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 586, in convert
custom_conversion_functions=custom_conversion_functions)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 243, in _convert_pb_to_mlmodel
tensors_evaluated = sess.run(tensors, feed_dict=input_feed_dict)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
run_metadata_ptr)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
run_metadata)
File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value