在iOS上实现TensorFlow注意OCR

时间:2017-07-08 19:38:35

标签: c++ ios tensorflow ocr tensorflow-serving

我已成功训练(使用初始V3权重作为初始化)此处描述的注意OCR模型:https://github.com/tensorflow/models/tree/master/attention_ocr并将生成的检查点文件冻结为图形。如何使用iOS上的C ++ API实现此网络?

提前谢谢。

1 个答案:

答案 0 :(得分:2)

根据其他人的建议,您可以使用一些现有的iOS演示(12)作为起点,但请密切注意以下细节:

  1. 确保使用正确的工具“冻结”模型。 SavedModel是Tensorflow模型的通用序列化格式。
  2. 模型导出脚本可以并且通常执行某种输入规范化。请注意,Model.create_base函数需要tf.float32形状的张量[batch_size,height,width,channels],其值将标准化为[-1.25, 1.25]。如果您将图像规范化作为TensorFlow计算图的一部分进行,请确保图像传递非标准化,反之亦然。
  3. 要获取输入/输出张量的名称,您只需打印它们,例如导出脚本中的某个位置:

    data_images = tf.placeholder(dtype=tf.float32, shape=[batch_size, height, width, channels], name='normalized_input_images')
    endpoints = model.create_base(data_images, labels_one_hot=None)
    print(data_images, endpoints.predicted_chars, endpoints.predicted_scores)