如何在Tensorflow中通过feed_dict提供一个又一个图像

时间:2017-05-17 14:21:26

标签: python tensorflow

我想通过feed_dict一个接一个地提供图像。 我克隆了这个存储库https://github.com/affinelayer/pix2pix-tensorflow,在“pix2pix.py”中,图像由tf.WholeFileReader()从文件夹中提供。但我想从内存中提供加载的图像。我制作了一个numpy.ndarray图像列表,并在for循环中使用了feed_dict。但是在图像输入后,程序没有响应。 我更改了原始代码的“test”部分,这是我的代码。

    tf.reset_default_graph()
    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    holder = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])


    with tf.name_scope("load_images"):
        paths=[]
        paths.append(1)

        assertion = tf.assert_equal(tf.shape(holder)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(holder)

        raw_input.set_shape([ None, None, 3])

        width = tf.shape(raw_input)[2] # [height, width, channels]

        #just copy for now
        a_images = cls.preprocess(raw_input[:,:,:])
        b_images = cls.preprocess(raw_input[:,:,:])

        inputs, targets = [b_images, a_images]

经过一些处理,以下是会话部分。 “arg_img”是一个参数,它是一个图像列表,converted_outputs是一个由pix2pix转换后的图像数组。

    saver = tf.train.Saver(max_to_keep=1)
    #logdir = cls.a.output_dir if (cls.a.trace_freq > 0 or cls.a.summary_freq > 0) else None
    #sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)

    with tf.Session() as sess:

        print("entered sess")

        if cls.a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(cls.a.checkpoint)
            saver.restore(sess, checkpoint)
        print("checkpoint loaded")



        for j, img in enumerate(arg_img):
            im=np.asarray(img,dtype=np.float32)
            print(type(arg_img))

            ret_im=(sess.run(converted_outputs, feed_dict={holder:im}))

            print("passed")
            out_img_list.append(ret_im)

    return out_img_list

0 个答案:

没有答案