Error when reading data from TFRecord files with string_input_producer

Time: 2017-06-26 07:26:46

Tags: tensorflow mnist

I wrote a script to convert the MNIST data to TFRecord format:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets("/path/to/data", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
num_examples = mnist.train.num_examples

num_shards = 10
instances_per_shard = int(num_examples / num_shards)

idx = 0
for i in range(num_shards):
    filename = '/tmp/mnist/tfrecord-%.2d' % i
    writer = tf.python_io.TFRecordWriter(filename)
    for j in range(instances_per_shard):
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _bytes_feature(labels[idx].tostring()),
            'image_raw': _bytes_feature(images[idx].tostring())
        }))
        writer.write(example.SerializeToString())
        idx += 1
    writer.close()

I then read the data back from the TFRecord files:

import tensorflow as tf

files = tf.train.match_filenames_once('/tmp/mnist/tfrecord-*')
filename_queue = tf.train.string_input_producer(files, shuffle=False)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string)
    }
)

image = tf.decode_raw(features['image_raw'], tf.uint8)
decode_image = tf.reshape(image, [28, 28, 1])

label = features['label']
#label = tf.decode_raw(features['label'], tf.uint8)
#label = tf.reshape(label, [10])

batch_size = 4
capacity = 1000 + 3 * batch_size

example_batch, label_batch = tf.train.shuffle_batch([decode_image, label], batch_size=batch_size,
                                                    capacity=capacity, min_after_dequeue=30)

with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(4):
        cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
        print(cur_label_batch)

    coord.request_stop()
    coord.join(threads)

This runs fine. But if I uncomment these two lines:

label = tf.decode_raw(features['label'], tf.uint8)
label = tf.reshape(label, [10])

I get the following error:

Caused by op 'shuffle_batch', defined at:
  File "/home/chenk/workspace/tflearn/Learning/create_batch.py", line 27, in <module>
    capacity=capacity, min_after_dequeue=30)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 1217, in shuffle_batch
    name=name)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 788, in _shuffle_batch
    dequeued = queue.dequeue_many(batch_size, name=name)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 457, in dequeue_many
    self._queue_ref, n=n, component_types=self._dtypes, name=name)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 946, in _queue_dequeue_many_v2
    timeout_ms=timeout_ms, name=name)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 4, current size 0)
 [[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_UINT8, DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

Is something wrong with my code? What is the correct way to do this?

Thanks!

1 answer:

Answer 0: (score: 0)

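The MNIST images are uint8, but the labels are float64. When you serialize a label with tostring(), each float64 value becomes 8 bytes, so a 10-element one-hot label is stored as 80 bytes. When reading the TFRecords back, you should therefore decode the label as tf.float64. Decoding it as tf.uint8 yields 80 values instead of 10, so the error is really caused by the reshape() call, which cannot fit 80 elements into shape [10]. That failure occurs in the input pipeline feeding the queue, which is probably why the main thread only reports the shuffle queue as closed with 0 elements.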

label = tf.decode_raw(features['label'], tf.float64)
label = tf.reshape(label, [10])
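
The byte arithmetic can be checked without TensorFlow. The following is a minimal NumPy sketch, not part of the original answer: the array `label` is a hypothetical stand-in for one row of mnist.train.labels, assumed here to be a float64 one-hot vector, and np.frombuffer plays the role that tf.decode_raw plays in the reading script.

```python
import numpy as np

# Hypothetical stand-in for one row of mnist.train.labels (assumed float64, one-hot).
label = np.zeros(10, dtype=np.float64)
label[3] = 1.0

# tobytes() returns the same raw bytes as the tostring() call used in the question.
raw = label.tobytes()
print(len(raw))  # 80: 10 values * 8 bytes each

# Reinterpreting the same bytes with two dtypes, analogous to tf.decode_raw.
as_uint8 = np.frombuffer(raw, dtype=np.uint8)
as_float64 = np.frombuffer(raw, dtype=np.float64)
print(as_uint8.shape, as_float64.shape)  # (80,) (10,)

# Reshaping the uint8 view to 10 elements fails, mirroring the tf.reshape failure.
try:
    as_uint8.reshape(10)
    reshape_ok = True
except ValueError:
    reshape_ok = False
print(reshape_ok)  # False
```

Only the float64 view has 10 elements and reproduces the original label, which is consistent with decoding as tf.float64 in the fix above.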