如何使用Tensorflow Cifar10教程代码进行推理?

时间:2018-05-23 21:04:09

标签: python tensorflow

我是TensorFlow的绝对新手。

如果我想要使用Cifar10 TensorFlow教程中的代码尝试分类图片(或图片集),我该怎么做?

我完全不知道从哪里开始。

1 个答案:

答案 0 :(得分:0)

  1. 完全根据教程使用基本CIFAR10数据集训练模型。
  2. 使用您自己的输入创建一个新图表 - 最简单的方法是使用tf.placeholder并按以下方式提供数据,但还有很多其他方法。
  3. 开始会话,加载以前保存的权重。
  4. 如果您使用上述feed_dict,请使用placeholder运行会话。
  5. import tensorflow as tf
    
    train_dir = '/tmp/cifar10_train'  # or use FLAGS as in the train example
    batch_size = 8
    height = 32
    width = 32
    
    image = tf.placeholder(shape=(batch_size, height, width, 3), dtype=tf.uint8)
    std_img = tf.image.per_image_standardization(image)
    logits = cifar10.inference(std_img)
    predictions = tf.argmax(logits, axis=-1)
    
    def get_image_data_batches():
        n_batchs = 100
        for i in range(n_batchs):
            yield (np.random.uniform(size=(batch_size, height, width, 3)*255).astype(np.uint8)
    
    def do_stuff_with(logit_vals, prediction_vals):
        pass
    
    with tf.Session() as sess:
        # restore variables
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(train_dir))
        # run inference
        for batch_data in get_image_data_batches():
            logit_vals, prediction_vals = sess.run([logits, predictions], feed_dict={image: image_data})
            do_stuff_with(logit_vals, prediction_vals)
    

    有更好的方法可以将数据添加到图表中(请参阅tf.data.Dataset),但我相信tf.placeholder是最简单的学习方法,并且最初可以运行并运行。

    同时查看tf.estimator.Estimator以获得更简洁的会话管理方式。它与本教程中的方式完全不同,而且灵活性稍差,但对于标准网络,它们可以节省您编写大量样板代码的时间。