Passing two queues to a TensorFlow training op

Time: 2017-03-29 21:56:06

Tags: tensorflow

I'm trying to build a train op that uses a tf.RandomShuffleQueue, based on TensorFlow's CIFAR10 example, where my labels come from the filenames, as described in (Accessing filename from file queue in Tensor Flow). How can I do this with the code below?

When I try to run the following code, where path is a directory containing many files:

filenames = [os.path.join(path, f) for f in os.listdir(path)][1:]
file_fifo = tf.train.string_input_producer(filenames,
                                           shuffle=False,
                                           capacity=len(filenames))
reader = tf.WholeFileReader()
key, value = reader.read(file_fifo)
image = tf.image.decode_png(value, channels=3, dtype=tf.uint8)
image.set_shape([config.image_height, config.image_width, config.image_depth])
image = tf.cast(image, tf.float32)
image = tf.divide(image, 255.0)
labels = [int(os.path.basename(f).split('_')[-1].split('.')[0]) for f in filenames]
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
label_enqueue = label_fifo.enqueue_many([tf.constant(labels)])
label = label_fifo.dequeue()
bq = tf.RandomShuffleQueue(capacity=16 * batch_size,
                           min_after_dequeue=8 * batch_size,
                           dtypes=[tf.float32, tf.int32])
batch_enqueue_op = bq.enqueue([image, label_enqueue])
runner = tf.train.queue_runner.QueueRunner(bq, [batch_enqueue_op] * num_threads)
tf.train.add_queue_runner(runner)

# Read 'batch' labels + images from the example queue.
images, labels = bq.dequeue_many(FLAGS.batch_size)
labels = tf.reshape(labels, [FLAGS.batch_size, 1])

I get obvious errors, because I know my code doesn't make much sense. I have two FIFO queues, file_fifo and label_fifo, but I don't know how to feed the output of label_fifo into my tf.RandomShuffleQueue.

Can anyone help? Thanks :-)

1 answer:

Answer 0 (score: 0):

I changed the code to:

filenames = [os.path.join(FLAGS.data_path, f) for f in os.listdir(FLAGS.data_path)][1:]
np.random.shuffle(filenames)
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
reader = tf.WholeFileReader()
key, value = reader.read(file_fifo)
image = tf.image.decode_png(value, channels=3, dtype=tf.uint8)
image.set_shape([config.image_height, config.image_width, config.image_depth])
image = tf.cast(image, tf.float32)
image = tf.divide(image, 255.0)

labels = [int(os.path.basename(f).split('_')[-1].split('.')[0]) for f in filenames]
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
label_enqueue = label_fifo.enqueue_many([tf.constant(labels)])
label = label_fifo.dequeue()

if is_train:
    images, label_batch = tf.train.shuffle_batch([image, label],
                                                 batch_size=FLAGS.batch_size,
                                                 num_threads=FLAGS.num_threads,
                                                 capacity=16 * FLAGS.batch_size,
                                                 min_after_dequeue=8 * FLAGS.batch_size)
labels = tf.reshape(label_batch, [FLAGS.batch_size, 1])
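The label extraction above assumes filenames where the integer after the last underscore, before the extension, is the class label. A minimal pure-Python check of that parsing (the example filenames here are made up for illustration):

```python
import os

def label_from_filename(f):
    # Mirrors the expression used above: take the text after the last
    # underscore, strip the extension, and parse it as an integer label.
    return int(os.path.basename(f).split('_')[-1].split('.')[0])

print(label_from_filename('/data/train/cat_001_3.png'))  # 3
print(label_from_filename('dog_7.png'))                  # 7
```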

For training I have:

class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        if self._step % int(config.train_examples / FLAGS.batch_size) == 0 or self._step == 0:
            run_context.session.run(label_enqueue)
        return tf.train.SessionRunArgs({'loss': loss, 'net': net})
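A quick sanity check of the condition in before_run above (the numbers 10 and 2 are made up for illustration): it fires once per epoch, so the label FIFO gets refilled exactly when it would otherwise run dry.

```python
# Hypothetical values standing in for config.train_examples and
# FLAGS.batch_size; the condition itself matches the hook above.
train_examples, batch_size = 10, 2
steps_per_epoch = train_examples // batch_size  # 5

refill_steps = [step for step in range(12)
                if step % steps_per_epoch == 0 or step == 0]
print(refill_steps)  # [0, 5, 10]
```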

and I run the training as:

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_path,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()],
        config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

The training starts, but it only runs the first step and then hangs, probably because it is waiting for some queue operation.
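That hang is consistent with label_fifo running empty: a dequeue on an empty FIFO blocks until something enqueues again. A toy sketch of that behavior using Python's standard queue module (not the actual TensorFlow graph):

```python
import queue

# Toy model of the situation: a fixed-size FIFO holding all labels.
# Each training step dequeues one label; once the queue is empty, the
# next get() would block forever unless the labels are re-enqueued.
labels = [0, 1, 2]
label_fifo = queue.Queue(maxsize=len(labels))

for lbl in labels:          # analogous to running label_enqueue once
    label_fifo.put(lbl)

steps = [label_fifo.get() for _ in range(len(labels))]
print(steps)                # [0, 1, 2]
print(label_fifo.empty())   # True: one more get() would block,
                            # which is the "hang" seen in training
```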