我在python中使用TFCoreml将Tensorflow模型转换为CoreML,以便使用CoreML库在iOS设备上进行开发。
我使用以下python代码尝试将模型转换为CoreML。
将tfcoreml导入为tf_converter
tf_converter.convert(tf_model_path = 'frozen_inference_graph.pb',
mlmodel_path = 'ml_model.mlmodel',
output_feature_names = ['SemanticPredictions:0'],
input_name_shape_dict = {'ImageTensor:0' : [1, 512, 512, 3]})
这会出现以下错误:
Slice类型的OP缺少Shape Translator。
我进一步阅读了TFCoreml的文档,它指出不完全支持Slice,并且需要一些自定义转换代码才能起作用。在TFCoreml文档中,它建议将冻结的图分解为子图并分别进行转换,然后在转换后将它们合并在一起。
我更新了代码以使用自定义图层,但是我不太了解自定义转换功能的工作原理。
仅需要一些指针就可以开始了解如何编写这些自定义转换方法,从而可以解决将Tensorflow模型转换为CoreML的问题。
[编辑]
我进一步阅读了TFCoreml示例和文档,并针对此问题调整了我的解决方案。
import tfcoreml as tf_converter
def _convert_slice(**kwargs):
tf_op = kwargs["op"]
coreml_nn_builder = kwargs["nn_builder"]
constant_inputs = kwargs["constant_inputs"]
params = NeuralNetwork_pb2.CustomLayerParams()
params.className = 'Slice'
params.description = "Custom layer that corresponds to the slice TF op"
# get the value of begin
begin = constant_inputs.get(tf_op.inputs[1].name, [0,0,0,0])
size = constant_inputs.get(tf_op.inputs[2].name, [0,0,0,0])
# add begin and size as two repeated weight fields
begin_as_weights = params.weights.add()
begin_as_weights.floatValue.extend(map(float, begin))
size_as_weights = params.weights.add()
size_as_weights.floatValue.extend(map(float, size))
coreml_nn_builder.add_custom(name=tf_op.name,
input_names=[tf_op.inputs[0].name],
output_names=[tf_op.outputs[0].name],
custom_proto_spec=params)
coreml_model = tfcoreml.convert(
tf_model_path='frozen_inference_graph.pb',
mlmodel_path='my_model.mlmodel',
input_name_shape_dict={'ImageTensor:0':[1, 512, 512, 3]},
output_feature_names=['SemanticPredictions:0'],
add_custom_layers=True,
custom_conversion_functions={'Slice': _convert_slice}) # dictionary has op name as the key
print("\n \n ML Model layers info: \n")
# inspect the CoreML model: this should be same as the one we got above
spec = coreml_model.get_spec()
_print_coreml_nn_layer_info(spec)
我仍然遇到与以前相同的错误
Slice类型的OP缺少Shape Translator。
但我确实注意到我也收到此错误/警告
custom_conversion_functions = {'Slice':_convert_slice})#字典将操作名称作为关键字
我们将为您提供任何帮助