How do I convert TFRecords into numpy arrays?

Time: 2016-03-16 04:37:37

Tags: tensorflow

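The main idea is to convert TFRecords into numpy arrays. Assume the TFRecord file stores images. Specifically, I want to:

  1. Read the TFRecord file and convert each image into a numpy array.
  2. Write the images out as 1.jpg, 2.jpg, and so on.
  3. At the same time, write each file name and its label to a text file, like this:
    1.jpg 2
    2.jpg 4
    3.jpg 5

I am currently using the following code: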

    import tensorflow as tf
    import os
    
    def read_and_decode(filename_queue):
      reader = tf.TFRecordReader()
      _, serialized_example = reader.read(filename_queue)
      features = tf.parse_single_example(
          serialized_example,
          # Defaults are not specified since both keys are required.
          features={
              'image_raw': tf.FixedLenFeature([], tf.string),
              'label': tf.FixedLenFeature([], tf.int64),
              'height': tf.FixedLenFeature([], tf.int64),
              'width': tf.FixedLenFeature([], tf.int64),
              'depth': tf.FixedLenFeature([], tf.int64)
          })
      image = tf.decode_raw(features['image_raw'], tf.uint8)
      label = tf.cast(features['label'], tf.int32)
      height = tf.cast(features['height'], tf.int32)
      width = tf.cast(features['width'], tf.int32)
      depth = tf.cast(features['depth'], tf.int32)
      return image, label, height, width, depth
    
    with tf.Session() as sess:
      filename_queue = tf.train.string_input_producer(["../data/svhn/svhn_train.tfrecords"])
      image, label, height, width, depth = read_and_decode(filename_queue)
      image = tf.reshape(image, tf.pack([height, width, 3]))
      image.set_shape([32,32,3])
      init_op = tf.initialize_all_variables()
      sess.run(init_op)
      print (image.eval())
    

To start with, I am just trying to read at least one image. When I run the code, it hangs.
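For reference, the decode step on its own (`tf.decode_raw` followed by `tf.reshape`) can be reproduced in plain numpy. This is a minimal sketch, assuming the bytes are stored row by row in height, width, depth order, as the reshape in the question implies. The helper name `raw_bytes_to_image` and the 2x2 test image are made up for illustration.

```python
import numpy as np

def raw_bytes_to_image(raw_bytes, height, width, depth):
    """Turn a raw uint8 byte string into a (height, width, depth) array."""
    flat = np.frombuffer(raw_bytes, dtype=np.uint8)
    expected = height * width * depth
    if flat.size != expected:
        raise ValueError("expected %d bytes, got %d" % (expected, flat.size))
    # frombuffer returns a read-only view; copy so the result can be modified.
    return flat.reshape(height, width, depth).copy()

# Hypothetical 2x2 RGB image standing in for one record's 'image_raw' bytes.
fake_raw = bytes(range(12))
image = raw_bytes_to_image(fake_raw, 2, 2, 3)
print(image.shape, image.dtype)
```

The size check is there because a reshape with the wrong dimensions would otherwise fail with a less helpful message.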

1 Answer:

Answer 0 (score: 17)

Oops, this was a silly mistake on my part. I used string_input_producer but forgot to start the queue runners.
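To see why missing queue runners cause a hang, here is a rough analogy using only the Python standard library. It is not how TensorFlow implements its queues: `filename_queue`, `reader_read`, and `fill_queue` are stand-ins invented for this sketch, and a timeout replaces the indefinite blocking so the script can finish.

```python
import queue
import threading

filename_queue = queue.Queue()

def reader_read(timeout):
    """Stand-in for reader.read: wait for a filename, or give up after timeout."""
    try:
        return filename_queue.get(timeout=timeout)
    except queue.Empty:
        return None

# No producer thread has been started, so nothing ever arrives.
# (The real read op has no timeout and would wait forever.)
before = reader_read(timeout=0.2)

def fill_queue():
    filename_queue.put("svhn_train.tfrecords")

# Starting the producer thread plays the role of tf.train.start_queue_runners.
producer = threading.Thread(target=fill_queue)
producer.start()
after = reader_read(timeout=2.0)
producer.join()

print(before, after)
```

The first read comes back empty because no thread is feeding the queue; the second succeeds once the producer is running.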

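The corrected session block:

    with tf.Session() as sess:
      filename_queue = tf.train.string_input_producer(["../data/svhn/svhn_train.tfrecords"])
      image, label, height, width, depth = read_and_decode(filename_queue)
      # Note: tf.pack was renamed tf.stack in TensorFlow 1.0.
      image = tf.reshape(image, tf.pack([height, width, 3]))
      image.set_shape([32,32,3])
      init_op = tf.initialize_all_variables()
      sess.run(init_op)
      # Start the threads that fill filename_queue; without them the read blocks forever.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)
      for i in range(1000):
        # sess.run returns numpy values: the image array and its label.
        example, l = sess.run([image, label])
        print (example,l)
      # Ask the queue threads to stop, then wait for them to finish.
      coord.request_stop()
      coord.join(threads)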
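With the arrays available, steps 2 and 3 from the question (numbered image files plus a text index) could be handled along these lines. This is only a sketch: it assumes Pillow is installed for the JPEG writing, and `export_examples`, `save_with_pillow`, and the `labels.txt` file name are hypothetical choices, not part of the original code.

```python
import os

def save_with_pillow(array, path):
    # Assumes Pillow is installed and `array` is an HxWx3 uint8 numpy array.
    from PIL import Image
    Image.fromarray(array).save(path)

def export_examples(examples, out_dir, save_image=save_with_pillow):
    """Write each (image, label) pair as N.jpg and add a line 'N.jpg label'."""
    os.makedirs(out_dir, exist_ok=True)
    index_path = os.path.join(out_dir, "labels.txt")
    with open(index_path, "w") as index_file:
        # Number files from 1 to match the 1.jpg, 2.jpg naming in the question.
        for i, (image, label) in enumerate(examples, start=1):
            filename = "%d.jpg" % i
            save_image(image, os.path.join(out_dir, filename))
            index_file.write("%s %d\n" % (filename, label))
    return index_path
```

One way to use it is to append each `(example, l)` pair from the loop to a list and pass that list to `export_examples` after the session closes. The `save_image` parameter is there so the image writer can be swapped out, for example for a different image library.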