无法读取TensorFlow上的数据

时间:2016-09-04 16:41:03

标签: python tensorflow

在此之前,我将输入图像转换为TFRecords文件。现在我有以下方法,我主要从教程中收集并修改了一点:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image/encoded'], tf.uint8)
    label = tf.cast(features['image/class/label'], tf.int32)

    reshaped_image = tf.reshape(image,[size[0], size[1], 3])
    reshaped_image = tf.image.resize_images(reshaped_image, size[0], size[1], method = 0)
    reshaped_image = tf.image.per_image_whitening(reshaped_image)
    return reshaped_image, label

def inputs(train, batch_size, num_epochs):
    filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)
    return images, sparse_labels

但是当我尝试在iPython / Jupyter上调用批处理时,进程永远不会结束(似乎有一个循环)。我这样称呼它:

batch_x, batch_y = inputs(True, 100,1)
print batch_x.eval()

1 个答案:

答案 0 :(得分:1)

看起来你错过了对tf.train.start_queue_runners()的调用,它启动了驱动输入管道的后台线程(例如,其中一些是num_threads=2在{{1}调用中隐含的线程并且tf.train.string_input_producer()也需要后台线程)。以下小改动应取消阻止:

tf.train.shuffle_batch()