在Tensorflow中提供可变宽度/高度图像以进行推理

时间:2017-08-03 08:14:05

标签: tensorflow

我跟着this恢复我的模型并对单个图像进行推理,我的神经网络是完全卷积的,因此它可以处理可变宽度/高度的图像。 我在训练时使用以下代码。请注意我用于训练的图像是64宽度/高度,并由队列运行器提取。

feed_input = tf.placeholder(tf.float32, (1, None, None, 3), name=FEED_INPUT_PLACEHOLDER_NAME)
temp_feed_input = np.ndarray(shape=(1, 1, 1, 3), dtype=float)
...
sess.run([train_op], feed_dict={is_using_feed_input:False, feed_input:temp_feed_input})

我想对不同大小的图像进行推理,比如147x256,我使用下面的代码对单个图像进行推理,它引发了异常,说输入输入[1,147,256,3]与[不兼容] 1,?,?,3]。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from fuzhi_constants import *

def inference(model_dir, ldr_image_path):
    import os
    import cv2
    import time
    import numpy as np
    import tensorflow as tf
    regression_path = os.path.join(os.path.dirname(ldr_image_path), 'regression.exr')
    with tf.Graph().as_default():
        with tf.Session() as sess:
            ldr_img = cv2.imread(ldr_image_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
            ldr_img = np.expand_dims(ldr_img, axis=0)
            ckpt = tf.train.get_checkpoint_state(model_dir)
            saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path+".meta")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # start queue runner
            # see e.g. https://www.tensorflow.org/programmers_guide/reading_data
            coord = tf.train.Coordinator()
            threads = []
            for qr in sess.graph.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                             start=True))
            start_time = time.time()
            regression = sess.run('conv_out/sub:0', feed_dict={IS_USING_FEED_INPUT_PLACEHOLDER_NAME+':0':True, FEED_INPUT_PLACEHOLDER_NAME+':0':ldr_img})
            elapsed_time = time.time() - start_time
            print('elapsed time: %s seconds' % (elapsed_time))
            cv2.imwrite(regression_path, regression[0])

if __name__ == '__main__':
    inference('G:\\neural_network\\optimizer=Adadelta', 'G:\\input_ldr.png')

0 个答案:

没有答案