tf.TFRecordReader永远不会用尽

时间:2017-09-17 08:07:44

标签: python tensorflow

这是我从tfrecords读取数据的代码:tfrecords包含10,000个样本

def read_notmnist(filename_queue):
  """
  Read examples from notmnist tfrecords data
  :param filename_queue: a Queue of strings with the filenames to read from
  :return:
  image: a [IMAGE_SIZE,IMAGE_SIZE,1] uint8 tensor represent image data.
  label: a int32 Tensor with the label in the range 0 - (NUM_CLASSES-1)
  """
  reader = tf.TFRecordReader()
  _, serialize_example = reader.read(filename_queue)

  features = tf.parse_single_example(
    serialize_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'image_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 1])

  label = tf.cast(features["label"], tf.int32)
  # pdb.set_trace()
  # label = tf.reshape(label,[])
  return image, label

我写了一个测试程序:

def main(unused_args):
    image,label = test_read_notmnist(data_dir=FLAGS.data_dir,batch_size=FLAGS.train_batch_size)

    print("start run")
    sess = tf.Session()
    tf.train.start_queue_runners(sess=sess)
    sess.run(tf.global_variables_initializer())
    all_pass = True
    for i in range(100000000):
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        image_run,label_run = sess.run([image,label])
        all_pass = all_pass and (image_run.shape == (FLAGS.train_batch_size, 28, 28, 1) ) and \
                   (label_run.shape == (FLAGS.train_batch_size,))
        # print("Sample Processed {0}".format((i+1)*(FLAGS.train_batch_size)))
        print("Sample Processed {0}".format((i+1)))

for循环可以运行100000000次而不会出现任何错误。我没有找到任何参数来自定义此行为,如果我只想要所有数据一次,我应该怎么做(在验证方案中,我只想推断验证数据集中的每个示例一次)

0 个答案:

没有答案