这是我从tfrecords读取数据的代码:tfrecords包含10,000个样本
def read_notmnist(filename_queue):
"""
Read examples from notmnist tfrecords data
:param filename_queue: a Queue of strings with the filenames to read from
:return:
image: a [IMAGE_SIZE,IMAGE_SIZE,1] uint8 tensor represent image data.
label: a int32 Tensor with the label in the range 0 - (NUM_CLASSES-1)
"""
reader = tf.TFRecordReader()
_, serialize_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialize_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string)
}
)
image = tf.decode_raw(features['image_raw'], tf.uint8)
image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 1])
label = tf.cast(features["label"], tf.int32)
# pdb.set_trace()
# label = tf.reshape(label,[])
return image, label
我写了一个测试程序:
def main(unused_args):
image,label = test_read_notmnist(data_dir=FLAGS.data_dir,batch_size=FLAGS.train_batch_size)
print("start run")
sess = tf.Session()
tf.train.start_queue_runners(sess=sess)
sess.run(tf.global_variables_initializer())
all_pass = True
for i in range(100000000):
# sess = tf_debug.LocalCLIDebugWrapperSession(sess)
image_run,label_run = sess.run([image,label])
all_pass = all_pass and (image_run.shape == (FLAGS.train_batch_size, 28, 28, 1) ) and \
(label_run.shape == (FLAGS.train_batch_size,))
# print("Sample Processed {0}".format((i+1)*(FLAGS.train_batch_size)))
print("Sample Processed {0}".format((i+1)))
for循环可以运行100000000次而不会出现任何错误。我没有找到任何参数来自定义此行为,如果我只想要所有数据一次,我应该怎么做(在验证方案中,我只想推断验证数据集中的每个示例一次)