Converting TensorFlow to CoreML with tf-coreml

Posted: 2019-10-01 17:43:04

Tags: python-3.x tensorflow tensorflow-lite coreml batch-normalization

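I have a multi-input network that uses a tf.bool tf.placeholder to control how batch normalization is performed during training versus validation/testing. I have been trying to convert this trained model to CoreML with the tf-coreml library, without success. I get the following error: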

tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value

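As I understand it, this error means that some nodes have no value, so the converter cannot run the model. I also understand that the error is related to control-flow operations (tied to the way batch normalization creates ops such as Switch and Merge). The TensorFlow source code contains the following test: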

def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)

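Note that my error is Retval[26] (I have also gotten [24], etc.), not Retval[0]. I assume this test checks the Switch "dead branch", i.e. the branch that should go unused during inference. The same code also tests the Merge "dead branch".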
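To make the "dead branch" idea concrete, here is a toy model in plain Python of how a Switch/Merge pair behaves. It is not TensorFlow code: the `DEAD` sentinel and the `switch`, `merge` and `fetch` functions are hypothetical stand-ins. `switch` returns `(output_false, output_true)` like `control_flow_ops.switch`, so with a predicate of `True` the first output is dead, which is exactly what `testSwitchDeadBranch` fetches. The index in the message is presumably the position of the failing tensor among the fetched tensors, which would explain why a converter that fetches many tensors in one `sess.run` call reports `Retval[26]` rather than `Retval[0]`.

```python
# Toy model (not TensorFlow) of TF1 dataflow Switch/Merge semantics.
# Assumption for illustration: a "dead" output is represented by DEAD.
DEAD = object()


def switch(data, pred):
    """Like Switch: returns (output_false, output_true); only one is live."""
    return (DEAD, data) if pred else (data, DEAD)


def merge(*inputs):
    """Like Merge: forwards the first live input together with its index."""
    for index, value in enumerate(inputs):
        if value is not DEAD:
            return value, index
    raise ValueError("Merge received no live input")


def fetch(value, position):
    """Like returning a fetched tensor: a dead value raises an error."""
    if value is DEAD:
        raise ValueError(f"Retval[{position}] does not have value")
    return value


data = [1, 2, 3, 4, 5, 6]
out_false, out_true = switch(data, True)

print(fetch(out_true, 0))  # live branch: [1, 2, 3, 4, 5, 6]
try:
    fetch(out_false, 0)  # dead branch, as in testSwitchDeadBranch
except ValueError as e:
    print(e)  # Retval[0] does not have value

# Merge still yields a value because one of its inputs is live.
print(merge(out_false, out_true))  # ([1, 2, 3, 4, 5, 6], 1)
```

Fetching the live output works, fetching the dead one raises, and a downstream Merge is fine as long as one input is live.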

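Am I missing some detail that could cause this error (it is certainly not the first error I have hit during conversion)? The way inference is run? The way batch normalization is implemented? The way the model is saved?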

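What I have done so far:

  • I am using TensorFlow 1.14.0
  • I know that the Switch and Merge ops created by tf.layers.batch_normalization are not compatible with CoreML
  • I have tried converting to TensorFlow Lite and ran into similar problems
  • I have followed the FaceNet conversion process (that model uses the same tf.bool logic for training, validation and testing), without success
  • I have tried GraphTransforms
  • I have tried scripts that remove/modify the control flow
  • I have built separate graphs to avoid the extra ops, without success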
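Before converting, it can help to list which control-flow ops are actually left in the frozen graph. The sketch below is a hypothetical helper, not part of tf-coreml or TensorFlow; it only assumes that the graph object has a `node` sequence whose entries have `name` and `op` fields, which is how the `NodeDef` entries of a `GraphDef` are laid out. The node names in the usage example are made up.

```python
from types import SimpleNamespace

# Op types emitted by TF1 dataflow control flow (tf.cond, tf.while_loop).
CONTROL_FLOW_OPS = {
    "Switch", "RefSwitch", "Merge", "RefMerge",
    "Enter", "Exit", "NextIteration", "LoopCond",
}


def find_control_flow_nodes(graph_def):
    """Return (name, op) pairs for every control-flow node in graph_def.

    graph_def only needs a `.node` iterable whose items have `.name` and
    `.op` attributes; a GraphDef parsed from a .pb file has that shape.
    """
    return [(n.name, n.op) for n in graph_def.node if n.op in CONTROL_FLOW_OPS]


# Hypothetical stand-in for a frozen GraphDef (node names are invented).
fake_graph_def = SimpleNamespace(node=[
    SimpleNamespace(name="training", op="Placeholder"),
    SimpleNamespace(name="conv/batch_normalization/cond/Switch", op="Switch"),
    SimpleNamespace(name="conv/batch_normalization/cond/Merge", op="Merge"),
    SimpleNamespace(name="conv/Relu", op="Relu"),
])

print(find_control_flow_nodes(fake_graph_def))
```

With a real frozen graph you would parse the `.pb` file into a `GraphDef` (in TF 1.14, `tf.GraphDef()` followed by `ParseFromString(...)`) and pass that in; an empty result would suggest the Switch/Merge ops were actually removed.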

Note: I have trimmed most of the code for this post.

This is how batch normalization is implemented inside the convolution block:

import tensorflow as tf

training = tf.placeholder(tf.bool, shape = (), name = 'training')

def conv_layer(input, kernelSize, nFilters, poolSize, stride, input_channels = 1, name = 'conv'):
    with tf.name_scope(name):
        shape = [kernelSize, kernelSize, input_channels, nFilters]
        weights = new_weights(shape = shape)  # helper defined elsewhere (not shown)
        biases = new_biases(length = nFilters)  # helper defined elsewhere (not shown)
        conv = tf.nn.conv2d(input, weights, strides = [1, 2, 2, 1], padding = 'SAME', name = 'convL')
        conv += biases
        # Note: this result is overwritten by the max_pool on the next line
        pool = tf.reduce_max(conv, axis = [3], keepdims = True, name = 'pool')
        pool = tf.nn.max_pool(conv, ksize = [1, poolSize, poolSize, 1], strides = [1, stride, stride, 1], padding = 'SAME')
        # `training` is a tf.bool tensor, so the train/inference choice is made at run time
        bnorm = tf.layers.batch_normalization(pool, training = training, center = True, scale = True, fused = False, reuse = False)
        act = tf.nn.relu(bnorm)
        return act
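The arithmetic behind the two batch-normalization modes can be sketched in NumPy. This is an illustration under assumptions, not TensorFlow's implementation: it uses the `tf.layers.batch_normalization` defaults `momentum=0.99` and `epsilon=1e-3`, normalizes over every axis except the last (NHWC channels), and ignores options such as renormalization. The Python `if training:` is the choice that the `training` placeholder defers to run time; as far as I understand, when `training` is a tensor rather than a Python bool, TF 1.x builds that choice as a conditional, which is where the Switch/Merge ops come from. In inference mode the layer reduces to a fixed per-channel scale and offset, computed here by the hypothetical helper `folded_scale_offset`.

```python
import numpy as np


def batch_norm(x, gamma, beta, moving_mean, moving_var, training,
               momentum=0.99, epsilon=1e-3):
    """NumPy sketch of batch norm over the last (channel) axis.

    Returns (y, new_moving_mean, new_moving_var).
    """
    axes = tuple(range(x.ndim - 1))  # every axis except channels (NHWC)
    if training:
        # Training: use this batch's statistics and update the moving averages.
        mean = x.mean(axis=axes)
        var = x.var(axis=axes)
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference: use the stored moving averages only.
        mean, var = moving_mean, moving_var
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return y, moving_mean, moving_var


def folded_scale_offset(gamma, beta, moving_mean, moving_var, epsilon=1e-3):
    """Inference-mode batch norm collapsed into y = x * scale + offset."""
    scale = gamma / np.sqrt(moving_var + epsilon)
    offset = beta - moving_mean * scale
    return scale, offset


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 4, 3)).astype(np.float32)  # NHWC batch
gamma, beta = np.ones(3), np.zeros(3)
moving_mean, moving_var = np.zeros(3), np.ones(3)

# The inference path is just a per-channel multiply-add, with no branching.
y_infer, _, _ = batch_norm(x, gamma, beta, moving_mean, moving_var, training=False)
scale, offset = folded_scale_offset(gamma, beta, moving_mean, moving_var)
print(np.allclose(y_infer, x * scale + offset))  # True
```

The point of the comparison is that, once `training` is known to be `False`, nothing in the computation needs a branch; the conditional only exists because the flag is a graph input.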

Below is the code that trains and saves the model.

from os.path import join

import tensorflow as tf

saver = tf.train.Saver()

with tf.Session(config = config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(init_train_op)

    for epoch in range(MAX_EPOCHS):

        for step in range(10):

            l, _, se = sess.run(
                [loss, train_op, mean_squared_error],
                feed_dict = {training: True})

        print('\nRunning validation operation...')

        sess.run(init_val_op)
        for _ in range(10):
            val_out, val_l, val_se = sess.run(
                [out, val_loss, val_mean_squared_error],
                feed_dict = {training: False})

        sess.run(init_train_op) # switch back to training set

    #Save model
    print('Saving Model...\n')
    saver.save(sess, join(saveDir, './model_saver_validation'.format(modelIndex)), write_meta_graph = True)

Below is the code that loads the model, remaps its inputs, runs inference, and freezes the graph.

import numpy as np
import tensorflow as tf

# Dummy data for inference
b = np.zeros((1, 80, 160, 1), np.float32)
ill = np.ones((1,3), np.float32)
is_train = False

def freeze():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            bIn = tf.placeholder(dtype=tf.float32, shape=[
                             1, 80, 160, 1], name='bIn')
            illumIn = tf.placeholder(dtype=tf.float32, shape=[
                                     1, 3], name='illumIn')
            training = tf.placeholder(tf.bool, shape=(), name = 'training')

            # Load the model metagraph and checkpoint
            meta_file = meta_graph #.meta file from saver.save()
            ckpt_file = checkpoint_file #checkpoint file

            # Load graph to redirect inputs from iterator to expected inputs
            saver = tf.train.import_meta_graph(meta_file, input_map={
                'IteratorGetNext:0': bIn,
                'IteratorGetNext:3': illumIn,
                'training:0': training},  clear_devices = True)

            tf.get_default_session().run(tf.global_variables_initializer())
            tf.get_default_session().run(tf.local_variables_initializer())
            saver.restore(tf.get_default_session(), ckpt_file)

            pred = tf.get_default_graph().get_tensor_by_name('Out:0')

            tf.get_default_session().run(pred, feed_dict={'bIn:0': b, 'illumIn:0': ill, 'training:0': is_train})

            # Retrieve the protobuf graph definition and fix the batch norm nodes
            input_graph_def = sess.graph.as_graph_def()

            # Freeze the graph def
            output_graph_def = freeze_graph_def(
                sess, input_graph_def, output_node_names)

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())

freeze()

Below is the code that converts the frozen graph to CoreML.

tfcoreml.convert(
    tf_model_path=frozen_graph,
    mlmodel_path='./coreml_model.mlmodel',
    output_feature_names=['Out:0'],
    input_name_shape_dict={
        'bIn:0': [1, 80, 160, 1],
        'illumIn:0': [1, 3], 
        'training:0': []})
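Since tf-coreml runs the graph with the inputs listed in `input_name_shape_dict` (the traceback shows it calling `sess.run(tensors, feed_dict=input_feed_dict)`), a quick consistency check between that dict and the graph's `Placeholder` nodes can rule out one source of trouble. Note that the call in the traceback also passes `'poseIn:0'`, which the snippet above does not. The helper and the fake graph below are hypothetical; they only assume `GraphDef`-like `node` entries with `name` and `op` fields, and tensor names of the form `<node>:<index>`.

```python
from types import SimpleNamespace


def check_input_names(graph_def, input_name_shape_dict):
    """Compare converter input names with the graph's Placeholder nodes.

    Returns (unfed, unknown): placeholders not listed in the dict, and dict
    keys that do not name any placeholder. Keys look like '<node>:<index>'.
    """
    placeholders = {n.name for n in graph_def.node if n.op == "Placeholder"}
    given = {key.rsplit(":", 1)[0] for key in input_name_shape_dict}
    unfed = sorted(placeholders - given)
    unknown = sorted(given - placeholders)
    return unfed, unknown


# Hypothetical frozen graph with four placeholders (poseIn as in the traceback).
fake_graph_def = SimpleNamespace(node=[
    SimpleNamespace(name="bIn", op="Placeholder"),
    SimpleNamespace(name="illumIn", op="Placeholder"),
    SimpleNamespace(name="poseIn", op="Placeholder"),
    SimpleNamespace(name="training", op="Placeholder"),
])

# Same keys as the tfcoreml.convert call in the question.
shapes = {"bIn:0": [1, 80, 160, 1], "illumIn:0": [1, 3], "training:0": []}
print(check_input_names(fake_graph_def, shapes))  # (['poseIn'], [])
```

An empty pair of lists means every placeholder is fed and every key names a real input.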

Below is the error raised by tf-coreml.

Loading the TF graph...
Graph Loaded.
Collecting all the 'Const' ops from the graph, by running it....

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tf2opencv.py", line 392, in <module>
    'illumIn:0': [1, 3], 'poseIn:0': [1, 16], 'training:0': []})
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 586, in convert
    custom_conversion_functions=custom_conversion_functions)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 243, in _convert_pb_to_mlmodel
    tensors_evaluated = sess.run(tensors, feed_dict=input_feed_dict)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value
