我是TensorFlow的绝对新手。
如果我想要使用Cifar10 TensorFlow教程中的代码尝试分类图片(或图片集),我该怎么做?
我完全不知道从哪里开始。
答案 0 :(得分:0)
tf.placeholder
并按以下方式提供数据,但还有很多其他方法。feed_dict
,请使用placeholder
运行会话。
import tensorflow as tf
train_dir = '/tmp/cifar10_train' # or use FLAGS as in the train example
batch_size = 8
height = 32
width = 32
image = tf.placeholder(shape=(batch_size, height, width, 3), dtype=tf.uint8)
std_img = tf.image.per_image_standardization(image)
logits = cifar10.inference(std_img)
predictions = tf.argmax(logits, axis=-1)
def get_image_data_batches():
n_batchs = 100
for i in range(n_batchs):
yield (np.random.uniform(size=(batch_size, height, width, 3)*255).astype(np.uint8)
def do_stuff_with(logit_vals, prediction_vals):
pass
with tf.Session() as sess:
# restore variables
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(train_dir))
# run inference
for batch_data in get_image_data_batches():
logit_vals, prediction_vals = sess.run([logits, predictions], feed_dict={image: image_data})
do_stuff_with(logit_vals, prediction_vals)
有更好的方法可以将数据添加到图表中(请参阅tf.data.Dataset
),但我相信tf.placeholder
是最简单的学习方法,并且最初可以运行并运行。
同时查看tf.estimator.Estimator
以获得更简洁的会话管理方式。它与本教程中的方式完全不同,而且灵活性稍差,但对于标准网络,它们可以节省您编写大量样板代码的时间。