如何在tensorflow attention_ocr上运行预训练模型?

时间:2017-08-07 16:29:18

标签: tensorflow inference pre-trained-model

我已经设法训练attention_ocr我的数据,我现在正在尝试进行推理运行(tensorflow版本1.2.1)。

我根据git README中提到的内容使用以下代码来使用预先训练的模型,但我总是得到一个重复字符列表,每次运行时都会发生变化(如[38,38,38 .. 。))。这显然是错误的,因为根据训练期间测试集的评估,我应该具有90%以上的角色准确度!

以前有人试过这个吗?或者有人可以提供一些修复它的提示吗?

images_placeholder = tf.placeholder(tf.float32, shape=[1, height, width, channels])
images_actual_data = cv2.imread(imageFname)
images_actual_data = cv2.cvtColor(images_actual_data, cv2.COLOR_BGR2RGB)

# some range normalization that is also done for training data
images_actual_data = images_actual_data.astype('float32')
images_actual_data -= images_actual_data.min()
images_actual_data /= images_actual_data.max()
images_actual_data -= 0.5
images_actual_data *= 2.5


model = common_flags.create_model(69,23,1,68) # based on the trained model
endpoints = model.create_base(images_placeholder, labels_one_hot=None)

with tf.Session() as sess:
    init_fn = model.create_init_fn_to_restore('/path-to-trained-models/model.ckpt-1126202', '')
    sess.run(tf.global_variables_initializer()) # tried to run sess.run(init_fn) here, but it fails
    predictions = sess.run(endpoints.predicted_chars, feed_dict={images_placeholder:images_actual_data.reshape(1,imHeight,imWidth,imChannel)})
    print predictions

2 个答案:

答案 0 :(得分:1)

我有点工作了。我在会话中没有正确运行。 无论如何,在运行预测之前添加以下行,解决了问题:

init_fn(sess)

它显然不是运行预训练模型的最佳方式(建议在git页面serving infrastructure上),但现在可以正常用于调试目的。

答案 1 :(得分:0)

我的猜测是范围标准化部分不正确。在培训模型uses tf.image.convert_image_dtype期间。所以请尝试替换:

images_actual_data -= images_actual_data.min()
images_actual_data /= images_actual_data.max()

images_actual_data /= 255.0