OutOfRangeError:tf.train.batch中的FIFOQueue'batch / fifo_queue'

时间:2017-08-18 14:35:08

标签: python tensorflow queue fifo tensorflow-gpu

我想解决这个问题。 我检查了num_epochs=None的{​​{1}},并将batch_size 64减少到16。 但是,OutOfRangeError不会消失。我尝试了一切(在stackoverflow上关于'OutOfRange,FIFOQueie,元素不足,train.batch,slice_input_producer'......如果你知道解决这个问题,请告诉我。

tf.train.slice_input_producer

'load_jpeg_with_tensorflow.py'在下面。

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
import time
import re
from datetime import datetime

from datetime import timedelta
from nets import PAM_Cnn, FD_Cnn
import load_jpeg_with_tensorflow

flags = tf.app.flags
FLAGS = flags.FLAGS
FLAGS.height = 250
FLAGS.width = 250
FLAGS.num_classes = 2
FLAGS.batch_size = 16

######################################################## Load Data ########################################################
main_dir = './data/LFWdata/LFW_train/'
log_dir = 'tmp/PAM/'

num_classes = 2 # Number of bubbles x 2

# batch_img, batch_label = load_jpeg_with_tensorflow.read_data_batch(train_dir, 'trainImageList.csv', height, width,
#                                                                    num_channels, batch_size=batch_size)

######################################################## Placeholder Variable ########################################################

X = tf.placeholder(tf.float32, shape=[None, FLAGS.height, FLAGS.width, 3], name='Input')
Y = tf.placeholder(tf.float32, shape=[None, num_classes], name='Label')
Y_cls = tf.argmax(Y, dimension=1)

image_batch, label_batch, file_batch = load_jpeg_with_tensorflow.read_data_batch(main_dir+'trainImageList.csv', FLAGS.height, FLAGS.width, FLAGS.batch_size)

######################################################## Training Process ########################################################

keep_prob = tf.placeholder(tf.float32)
prediction = FD_Cnn.build_model(X, keep_prob)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=Y))
tf.summary.scalar('loss', loss)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss)

validate_image_batch, validate_label_batch, validate_file_batch = load_jpeg_with_tensorflow.read_data_batch(main_dir+'testImageList.csv', FLAGS.height, FLAGS.width, FLAGS.batch_size)
label_max = tf.argmax(Y, 1)
pre_max = tf.argmax(prediction, 1)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
tf.summary.scalar('accuracy', accuracy)

startTime = datetime.now()
iteration = 20

summary = tf.summary.merge_all()
######################################################## TensorFlow Run ########################################################
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(tf.initialize_all_variables())

    for i in range(iteration):
        images_, labels_ = sess.run([image_batch, label_batch])
        # images_ = image_batch
        # labels_ = label_batch
        sess.run(optimizer, feed_dict={X : images_, Y : labels_, keep_prob:0.5})

        if i % 10 == 0:
            now = datetime.now() - startTime
            print('## time:', now, ' steps:', i)

            rt = sess.run([label_max, pre_max, loss, accuracy], feed_dict={X : images_,
                                                                           Y : labels_,
                                                                           keep_prob : 1.0})
            print('Prediction loss:', rt[2], ' accuracy:', rt[3])
            # validation steps
            validate_images_, validate_labels_ = sess.run([validate_image_batch, validate_label_batch])
            rv = sess.run([label_max, pre_max, loss, accuracy], feed_dict={X: validate_images_,
                                                                           Y: validate_labels_,
                                                                           keep_prob: 1.0})
            print('Validation loss:', rv[2], ' accuracy:', rv[3])
            if (rv[3] > 0.9):
                break
            # validation accuracy
            summary_str = sess.run(summary, feed_dict={X: validate_images_,
                                                       Y: validate_labels_,
                                                       keep_prob: 1.0})
            summary_writer.add_summary(summary_str, i)
            summary_writer.flush()

        saver.save(sess, 'face_recog')  # save session
    coord.request_stop()
    coord.join(threads)
    print('finish')

0 个答案:

没有答案