我编写了一个脚本,将MNIST数据转换为TFRecord格式:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
def _init64_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Load MNIST, requesting uint8 image data and one-hot encoded labels.
mnist = input_data.read_data_sets("/path/to/data", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
num_examples = mnist.train.num_examples

# The training set is split evenly across `num_shards` TFRecord files.
num_shards = 10
# Floor division keeps this an exact integer; any remainder
# (num_examples % num_shards) is simply not written.
instances_per_shard = num_examples // num_shards

idx = 0  # running index into images/labels across all shards
for i in range(num_shards):
    filename = '/tmp/mnist/tfrecord-%.2d' % i
    writer = tf.python_io.TFRecordWriter(filename)
    try:
        for _ in range(instances_per_shard):
            # NOTE: the raw bytes carry no dtype information. Whoever reads
            # these records must decode each feature with the same dtype the
            # array has here (check labels.dtype / images.dtype).
            # tobytes() yields the same bytes as the deprecated tostring()
            # alias, which was removed in NumPy 2.0.
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _bytes_feature(labels[idx].tobytes()),
                'image_raw': _bytes_feature(images[idx].tobytes())
            }))
            writer.write(example.SerializeToString())
            idx += 1
    finally:
        # Close the shard even if serialization or writing fails part-way.
        writer.close()
然后从TFRecords文件中读取数据:
import tensorflow as tf
# Build a filename queue over every shard matching the pattern.
files = tf.train.match_filenames_once('/tmp/mnist/tfrecord-*')
filename_queue = tf.train.string_input_producer(files, shuffle=False)

# Read one serialized Example at a time and parse its two byte-string features.
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string)
    }
)
image = tf.decode_raw(features['image_raw'], tf.uint8)
decode_image = tf.reshape(image, [28, 28, 1])
# The label is passed through as the raw serialized byte string.
label = features['label']
#label = tf.decode_raw(features['label'], tf.uint8)
#label = tf.reshape(label, [10])

batch_size = 4
capacity = 1000 + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
    [decode_image, label], batch_size=batch_size,
    capacity=capacity, min_after_dequeue=30)

with tf.Session() as sess:
    # Both global and local variables are initialized before the queues start.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(4):
            cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
            print(cur_label_batch)
    finally:
        # Always stop and join the queue-runner threads, including when
        # sess.run raises (e.g. OutOfRangeError from a closed queue).
        coord.request_stop()
        coord.join(threads)
它运行良好。但如果我取消注释这两行:
# Decode the serialized label bytes as uint8 and reshape to 10 entries.
# This is the variant that produces the error shown below.
label = tf.decode_raw(features['label'], tf.uint8)
label = tf.reshape(label, [10])
我收到以下错误:
Caused by op 'shuffle_batch', defined at:
File "/home/chenk/workspace/tflearn/Learning/create_batch.py", line 27, in <module>
capacity=capacity, min_after_dequeue=30)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 1217, in shuffle_batch
name=name)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 788, in _shuffle_batch
dequeued = queue.dequeue_many(batch_size, name=name)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 457, in dequeue_many
self._queue_ref, n=n, component_types=self._dtypes, name=name)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 946, in _queue_dequeue_many_v2
timeout_ms=timeout_ms, name=name)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
self._traceback = _extract_stack()
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 4, current size 0)
[[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_UINT8, DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]
我的代码中有什么问题吗?这样做的正确方法是什么?
谢谢!
答案 0 :(得分:0)
MNIST图片的类型是uint8,但标签的类型是float64。用tostring()把数据写入TFRecord时,每个float64值会被转换成8个字节,因此10个标签值一共占80个字节。读取TFRecord时,应当按tf.float64来解码标签。如果按uint8解码,会得到80个元素而不是10个,错误实际上是由reshape()函数引起的。
# Decode the label bytes as float64 (8 bytes per value), then reshape
# to the 10 one-hot entries.
label = tf.decode_raw(features['label'], tf.float64)
label = tf.reshape(label, [10])