TensorFlow inference on a single image using the CIFAR-10 example

Time: 2016-07-04 05:55:55

Tags: tensorflow

I am trying to use the TensorFlow cifar10 example to run inference on a single image: https://www.tensorflow.org/versions/r0.8/tutorials/deep_cnn/index.html#convolutional-neural-networks

```python
import tensorflow as tf

import cifar10  # cifar10.py from the TensorFlow CIFAR-10 tutorial

FLAGS = tf.app.flags.FLAGS  # assumes checkpoint_dir is defined, as in cifar10_eval.py


def restore_vars(saver, sess):
    """Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage."""
    # sess.run(tf.initialize_all_variables())

    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # Restore from the checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    else:
        print('No checkpoint file found')
        return False


def eval_single_img():
    # decode_jpeg returns a uint8 tensor of shape [height, width, 3]
    input_img = tf.image.decode_jpeg(tf.read_file("test.jpg"), channels=3)
    # Needs exactly 3 * 32 * 32 values, i.e. a 32x32 RGB image
    input_img = tf.reshape(input_img, [3, 32, 32])
    input_img = tf.transpose(input_img, [1, 2, 0])
    reshaped_image = tf.cast(input_img, tf.float32)

    # Center-crop to the 24x24 size the network was trained on
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 24, 24)

    float_image = tf.image.per_image_whitening(resized_image)

    image = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size = 1)

    logits = cifar10.inference(image)

    _, top_k_pred = tf.nn.top_k(logits, k=5)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        restored = restore_vars(saver, sess)

        top_indices = sess.run([top_k_pred])
        print("Predicted ", top_indices[0], " for your input image.")
```
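To see what `image` looks like by the time it reaches `cifar10.inference`, the preprocessing steps can be mirrored in NumPy. This is only a shape-tracing sketch, assuming `test.jpg` decodes to a 32x32 RGB image (the reshape to `[3, 32, 32]` needs exactly 3072 values). `preprocess_like_post` and `fake_jpeg` are hypothetical names, and the random array stands in for real pixel data.

```python
import numpy as np


def preprocess_like_post(decoded):
    """Mirror the shape changes of eval_single_img's preprocessing in NumPy.

    `decoded` stands in for the output of tf.image.decode_jpeg, which is laid
    out as [height, width, channels]. Assumes a 32x32x3 input.
    """
    # tf.reshape(input_img, [3, 32, 32]): reinterpret the same 3072 values.
    chw = decoded.reshape(3, 32, 32)
    # tf.transpose(..., [1, 2, 0]): move the first axis to the end.
    hwc = chw.transpose(1, 2, 0).astype(np.float32)
    # resize_image_with_crop_or_pad(..., 24, 24) center-crops a larger image.
    off = (32 - 24) // 2
    cropped = hwc[off:off + 24, off:off + 24, :]
    # per_image_whitening: subtract the mean, divide by max(std, 1/sqrt(N)).
    adjusted_std = max(cropped.std(), 1.0 / np.sqrt(cropped.size))
    whitened = (cropped - cropped.mean()) / adjusted_std
    # tf.expand_dims(..., 0): add a leading batch dimension of size 1.
    return whitened[np.newaxis, ...]


rng = np.random.default_rng(0)
fake_jpeg = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
image = preprocess_like_post(fake_jpeg)
print(image.shape)  # (1, 24, 24, 3)

# Reading HWC data as CHW and transposing does not give back the original layout.
print(np.array_equal(fake_jpeg.reshape(3, 32, 32).transpose(1, 2, 0), fake_jpeg))  # False
```

The batch dimension that reaches the network is 1. The last line also shows that the reshape-then-transpose pair, which the tutorial applies to the channel-major bytes of the CIFAR-10 binary files, rearranges the pixels of an image that `decode_jpeg` already returns in height-width-channel order.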

**Error message:**

```
tensorflow.python.framework.errors.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [18,384] rhs shape= [2304,384]
	 [[Node: save/Assign_5 = Assign[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]]
Caused by op u'save/Assign_5', defined at:
```

**What might be causing this?**
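The two shapes in the error can be reproduced with plain arithmetic. The sketch assumes the imported `cifar10.py` flattens `pool2` with `tf.reshape(pool2, [FLAGS.batch_size, -1])` and that `FLAGS.batch_size` keeps a default of 128. Neither detail appears in the post, and `pool2_shape` / `local3_weight_rows` are hypothetical helpers.

```python
# Shape arithmetic for the local3/weights mismatch.
# ASSUMPTIONS (not shown in the post): conv1/conv2 keep H and W (stride 1,
# SAME padding), pool1/pool2 are stride-2 max pools with SAME padding,
# conv2 has 64 filters, and local3 flattens with
# tf.reshape(pool2, [FLAGS.batch_size, -1]).
DEFAULT_BATCH_SIZE = 128  # assumed default of FLAGS.batch_size in cifar10.py


def pool2_shape(batch, height=24, width=24, channels=64):
    """Shape of pool2 for a [batch, height, width, 3] input image batch."""
    for _ in range(2):  # pool1 and pool2
        # SAME padding with stride 2 gives ceil(size / 2).
        height = -(-height // 2)
        width = -(-width // 2)
    return (batch, height, width, channels)


def local3_weight_rows(actual_batch, flag_batch=DEFAULT_BATCH_SIZE):
    """First dimension of local3/weights when the graph is built this way."""
    b, h, w, c = pool2_shape(actual_batch)
    total = b * h * w * c
    # tf.reshape(pool2, [flag_batch, -1]) infers -1 as total // flag_batch,
    # and fails outright if the division is not exact.
    if total % flag_batch:
        raise ValueError("reshape would fail: %d values, %d rows" % (total, flag_batch))
    return total // flag_batch


print(local3_weight_rows(128))              # 2304: batch used in training
print(local3_weight_rows(1))                # 18: one image, flag still 128
print(local3_weight_rows(1, flag_batch=1))  # 2304: one image, flag set to 1
```

Under those assumptions, a graph built for a single image gets `local3/weights` of shape [18, 384] while the checkpoint written during training stores [2304, 384], which matches the reported lhs and rhs shapes.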

0 answers