TensorFlow Attention OCR inference code

Time: 2017-07-17 19:10:24

Tags: tensorflow attention-model

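I am trying to run Attention OCR from the TensorFlow models repository, https://github.com/tensorflow/models/tree/master/attention_ocr. I can find scripts for training and evaluating on the FSNS dataset, but there is no code for running inference on a single image. I would like to test the model on my own images to see how well it performs. Below is what I tried for the inference part, but I get the error "Attempting to use uninitialized value AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_1a_3x3/weights".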

      dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
      model = common_flags.create_model(dataset.num_char_classes,
                                        dataset.max_sequence_length,
                                        1, dataset.null_code)

      #images: A tensor of shape [batch_size, height, width, channels].
      images_actual_data = Image.open("imagename.jpg").resize((600, 150))
      #Increase dimension of data and make it 4D
      images_actual_data = np.expand_dims(images_actual_data,axis=0)/255.0
      slim.get_or_create_global_step()  
      master_checkpoint = "trained_files/attention_ocr_2017_05_17/model.ckpt-399731"
      inception_checkpoint = "inception-v3_2016_08_28/inception_v3.ckpt"

#inference code
      with tf.Session() as sess:
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        images_placeholder = tf.placeholder(dtype=tf.float32,shape=(1,150,600,3))
        endpoints = model.create_base(images_placeholder, labels_one_hot=None)
        model.create_init_fn_to_restore(master_checkpoint, inception_checkpoint)
        predictions = sess.run(endpoints.predicted_chars, feed_dict={images_placeholder: images_actual_data})
  

Here is the output log:

    DEBUG 2017-08-01 21:20:58.000825: model.py: 314 images: Tensor("Placeholder:0", shape=(1, 150, 600, 3), dtype=float32)
    DEBUG 2017-08-01 21:20:58.000827: model.py: 319 Views=1 single view: Tensor("AttentionOcr_v1/split:0", shape=(1, 150, 600, 3), dtype=float32)
    DEBUG 2017-08-01 21:20:58.000827: model.py: 186 Using final_endpoint=Mixed_5d
    DEBUG 2017-08-01 21:20:59.000906: model.py: 325 Conv tower: Tensor("AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Mixed_5d/concat:0", shape=(1, 16, 72, 288), dtype=float32)
    DEBUG 2017-08-01 21:20:59.000909: model.py: 328 Pooled views: Tensor("AttentionOcr_v1/pool_views_fn/STCK/Reshape:0", shape=(1, 1152, 288), dtype=float32)
    DEBUG 2017-08-01 21:20:59.000909: sequence_layers.py: 421 Use AttentionWithAutoregression as a layer class
    DEBUG 2017-08-01 21:21:01.000588: model.py: 331 chars_logit: Tensor("AttentionOcr_v1/sequence_logit_fn/SQLR/concat:0", shape=(1, 37, 134), dtype=float32)
    INFO 2017-08-01 21:21:01.000864: model.py: 511 Request to re-store 117 weights from trained_files/attention_ocr_2017_05_17/model.ckpt-399731
    INFO 2017-08-01 21:21:02.000021: model.py: 511 Request to re-store 104 weights from inception-v3_2016_08_28/inception_v3.ckpt

    Traceback (most recent call last):
      File "/Users/sxa091/everything/example_projects/tf_projects/models/attention_ocr/python/eval_image.py", line 84, in <module>
        app.run()
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File "/Users/sxa091/everything/example_projects/tf_projects/models/attention_ocr/python/eval_image.py", line 80, in main
        predictions = sess.run(endpoints.predicted_chars, feed_dict={images_placeholder: images_actual_data})
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 778, in run
        run_metadata_ptr)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 982, in _run
        feed_dict_string, options, run_metadata)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1032, in _do_run
        target_list, options, run_metadata)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1052, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_1a_3x3/weights
         [[Node: AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_1a_3x3/weights/read = Identity[T=DT_FLOAT, _class=["loc:@AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_1a_3x3/weights"], _device="/job:localhost/replica:0/task:0/cpu:0"](AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_1a_3x3/weights)]]

1 answer:

Answer 0 (score: 0)

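Everyone has their own preferred way of storing and reading images, so this snippet is intentionally incomplete.

Here is one possible way to populate images_actual_data: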

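    import skimage.io as io
    import numpy as np

    batch_size = 1  # assumption: number of images to load; not defined in the original snippet
    fn = '600x150_image_with_4_views_%i.jpg'
    # Read each image as a float array of shape [height, width, channels].
    images = [io.imread(fn % i, dtype='float') for i in range(batch_size)]
    # Stack into one batch of shape [batch_size, height, width, channels].
    images_actual_data = np.stack(images)
    # Normalize values; this assumes pixels are in [0, 1] (divide by 255.0 first if they are in [0, 255]).
    images_actual_data = 2.5*(images_actual_data - 0.5)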

Note that you need to apply the same normalization that was used during training.
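To make the normalization step concrete, here is a minimal NumPy-only sketch. The constants 0.5 and 2.5 come from the snippet above; the helper name `normalize_images` and the assumption that pixels are first scaled from [0, 255] to [0, 1] are illustrative, so check them against the preprocessing your training pipeline actually used.

```python
import numpy as np


def normalize_images(images_uint8, shift=0.5, scale=2.5):
    """Scale uint8 pixels to [0, 1], then apply scale * (x - shift).

    Hypothetical helper: shift and scale mirror the constants in the
    answer's snippet and must match whatever was used during training.
    """
    images = np.asarray(images_uint8, dtype=np.float32) / 255.0
    return scale * (images - shift)


# Stand-in for real image data: a batch of one 150x600 RGB image
# (shape [batch_size, height, width, channels]), all black except one white pixel.
batch = np.zeros((1, 150, 600, 3), dtype=np.uint8)
batch[0, 0, 0] = 255

normalized = normalize_images(batch)
print(normalized.shape, normalized.min(), normalized.max())
```

With these constants, black pixels map to -1.25 and white pixels to 1.25. If the model was trained with a different range, predictions on your own images are likely to degrade even when the checkpoint is restored correctly.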