Tensorflow - 使用批处理进行预测

时间:2016-09-05 06:21:40

标签: python machine-learning tensorflow

我尝试使用经过训练的卷积神经网络进行预测,稍微修改了示例专家tensorflow教程中的示例。我按照https://www.tensorflow.org/versions/master/how_tos/reading_data/index.html的说明从CSV文件中读取数据。

我训练了模型并评估了它的准确性。然后我保存了模型并将其加载到一个新的python脚本中进行预测。我仍然可以使用上面链接中详述的批处理方法,还是应该使用feed_dict?我在网上看过的大多数教程都使用后者。

我的代码如下所示,我基本上复制了从我的训练数据中读取的代码,该代码作为行存储在单个.csv文件中。 Conv_nn只是一个包含专家MNIST教程中详述的卷积神经网络的类。除了我运行图形的部分之外,大多数内容可能不是很有用。

我怀疑我已经严重混淆了训练和预测 - 我不确定测试图像是否正确地被送入预测操作,或者是否对两个数据集都使用相同的批处理操作是有效的。

filename_queue =  tf.train.string_input_producer(["data/test.csv"],num_epochs=None)

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Defaults force key value and label to int, all others to float.
record_defaults = [[1]]+[[46]]+[[1.0] for i in range(436)]
# Reads in a single row from the CSV and outputs a list of scalars.
csv_list = tf.decode_csv(value, record_defaults=record_defaults)
# Packs the different columns into separate feature tensors.
location = tf.pack(csv_list[2:4])
bbox = tf.pack(csv_list[5:8])
pix_feats = tf.pack(csv_list[9:])
onehot = tf.one_hot(csv_list[1], depth=98)
keep_prob = 0.5


# Creates batches of images and labels.
image_batch, label_batch = tf.train.shuffle_batch(
    [pix_feats, onehot],
    batch_size=50,num_threads=4,capacity=50000,min_after_dequeue=10000)

# Creates a graph of variables and operation nodes.
nn = Conv_nn(x=image_batch,keep_prob=keep_prob,pixels=33*13,outputs=98)

# Launch the default graph.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.restore(sess, 'model1.ckpt')
    print("Model restored.")

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)

    prediction=tf.argmax(nn.y_conv,1)

    pred = sess.run([prediction])

    coord.request_stop()
    coord.join(threads)

1 个答案:

答案 0 :(得分:0)

这个问题已经过时了,但无论如何我都会回答,因为它已经被观看了近1000次。

因此,如果您的模型有 Y 标签和 X 输入,那么

prediction=tf.argmax(Y,1)
result = prediction.eval(feed_dict={X: [data]}, session=sess)

这会评估单个输入,例如单个mnist图像,但它可以是批处理。