我有一个简单的MNIST数据分类模型,准确率在92%左右。
我想知道是否有任何方法可以提供带有数字的图像并将标签作为该数字的输出?图像可以来自mnist测试数据,而不是自定义图像,只是为了避免图像预处理?下面是我的模型的代码。
由于
import tensorflow as tf
#reset graph
tf.reset_default_graph()
#constants
learning_rate = 0.5
batch_size = 100
training_epochs = 5
logs_path = "/tmp/mnist/2"
#load mnist data set
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
with tf.name_scope('inputs'):
x = tf.placeholder(tf.float32, shape=[None,784], name = "image-input")
y_= tf.placeholder(tf.float32, shape=[None, 10], name = "labels-input")
#weights
with tf.name_scope("weights"):
W = tf.Variable(tf.zeros([784,10]))
#biases
with tf.name_scope("biases"):
b= tf.Variable(tf.zeros([10]))
#Activation function softmax
with tf.name_scope("softmax"):
#y is prediction
y = tf.nn.softmax(tf.matmul(x,W) +b)
#Cost function
with tf.name_scope('cross_entropy'):
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1])) #????
#Define Optimizer
with tf.name_scope('train'):
train_optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
#Accuracy
with tf.name_scope('Accuracy'):
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
tf.summary.scalar("cost",cross_entropy)
tf.summary.scalar("accuracy",accuracy)
#Merge all summaries into a single "operation" which will be executed in a session
summary_op = tf.summary.merge_all()
with tf.Session() as sess:
#initialize variables before using them
sess.run(tf.global_variables_initializer())
#log writer object
# writer = tf.train.SummaryWriter(logs_path, graph=tf.get_default_graph())
writer = tf.summary.FileWriter(logs_path,graph=tf.get_default_graph())
#training cycles
for epoch in range(training_epochs):
#number of batches in one epoch
batch_count = int(mnist.train.num_examples/batch_size)
for i in range(batch_count):
batch_x, batch_y = mnist.train.next_batch(batch_size)
_,summary = sess.run([train_optimizer,summary_op], feed_dict={x: batch_x, y_:batch_y})
writer.add_summary(summary,epoch * batch_count + i)
if epoch % 5 == 0:
print("Epoch: ",epoch)
print("Accuracy: ",accuracy.eval(feed_dict={x: mnist.test.images,y_:mnist.test.labels}))
print("Done")
答案 0 :(得分:2)
训练网络后,您可以通过
获取网络为新图像提供的标签new_image_label= sess.run(y, feed_dict={x: new_image})
请注意,new_image
的格式应与batch_x
的格式相同。将new_image
视为批量1的批次,因此如果batch_x
为2D,则new_image
也应为2D(形状1乘784)。
此外,如果您对batch_x
中的图片进行了一些预处理(例如规范化),则需要对new_image
执行相同操作。
您还可以使用与上面相同的代码同时获取多个图像的标签。只需将new_image
替换为多个图像new_images
的二维数组。