OutOfRangeError when creating batches from a tfrecord file

Date: 2018-04-09 14:47:55

Tags: tensorflow tfrecord

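I am writing a script that saves some features of my data to a tfrecord file. The features are numpy arrays (float32). When I read the tfrecord file back, I get the following error: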

OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 20, current size 0)
 [[Node: shuffle_batch = QueueDequeueManyV2[component_types=[DT_UINT8, DT_UINT8], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

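I have searched a lot, and apparently this error can have several different causes. So far I have not been able to fix it. I reproduced the problem with the following minimal code: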

  1. Saving toy data:

    import sys

    import numpy as np
    import tensorflow as tf

    def _bytes_feature(value):
        # Wrap a single bytes value in a tf.train.Feature.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    writer = tf.python_io.TFRecordWriter('stuff.tfrecords')

    for i in range(100):
        # Toy data: random float32 arrays.
        seq = np.random.uniform(size=(500, 300)).astype(np.float32)
        lbl = np.random.uniform(size=(90, 1)).astype(np.float32)

        # Store each array as raw bytes; dtype and shape are not saved.
        feature = {'train/lbl': _bytes_feature(tf.compat.as_bytes(lbl.tostring())),
                   'train/seq': _bytes_feature(tf.compat.as_bytes(seq.tostring()))}

        example = tf.train.Example(features=tf.train.Features(feature=feature))

        writer.write(example.SerializeToString())

    writer.close()
    sys.stdout.flush()
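
Because the arrays are stored as raw bytes, the reader has to supply the dtype and shape itself. As a sanity check that does not involve TensorFlow at all, the round trip can be sketched with NumPy alone. This is only an illustration under assumptions: `tobytes()` is used as the current name for what `tostring()` does, and `np.frombuffer` stands in for what `tf.decode_raw` followed by `tf.reshape` is meant to recover.

```python
import numpy as np

# Same toy shapes and dtype as in the writer above.
seq = np.random.uniform(size=(500, 300)).astype(np.float32)
lbl = np.random.uniform(size=(90, 1)).astype(np.float32)

# Serialize to raw bytes, as the writer does for each feature.
seq_bytes = seq.tobytes()
lbl_bytes = lbl.tobytes()

# float32 takes 4 bytes per element, so the byte counts are fixed by the shapes.
seq_nbytes = len(seq_bytes)  # 500 * 300 * 4
lbl_nbytes = len(lbl_bytes)  # 90 * 1 * 4

# Reinterpret the bytes as float32 and restore the shapes.
seq_back = np.frombuffer(seq_bytes, dtype=np.float32).reshape(500, 300)
lbl_back = np.frombuffer(lbl_bytes, dtype=np.float32).reshape(90, 1)

print(seq_nbytes, lbl_nbytes)
print(np.array_equal(seq, seq_back), np.array_equal(lbl, lbl_back))
```

If the round trip is exact here, the serialization step by itself is unlikely to be what empties the queue.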
    
  2. Reading the data:

    import tensorflow as tf

    batch_size = 20  # matches "requested 20" in the error message

    def read_and_decode_single_example(filename):
        # Filename queue limited to a single pass over the file.
        filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        f = {'train/lbl': tf.FixedLenFeature([], tf.string),
             'train/seq': tf.FixedLenFeature([], tf.string)}

        features = tf.parse_single_example(serialized_example, features=f)

        # Reinterpret the raw bytes as float32 and restore the original shapes.
        seq = tf.decode_raw(features['train/seq'], tf.float32)
        lbl = tf.decode_raw(features['train/lbl'], tf.float32)

        seq = tf.reshape(seq, [500, 300])
        lbl = tf.reshape(lbl, [90, 1])

        sbatch, lbatch = tf.train.shuffle_batch([seq, lbl],
                                                batch_size=batch_size,
                                                capacity=3 * batch_size,
                                                min_after_dequeue=batch_size)

        return sbatch, lbatch


    sbatch, lbatch = read_and_decode_single_example("stuff.tfrecords")

    with tf.Session() as sess:
        # Queue runner threads are started here, before the initializers below run.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Fetch one batch; this is the call that raises OutOfRangeError.
        s, l = sess.run([sbatch, lbatch])

        coord.request_stop()
        coord.join(threads)
    
  3. I am using Tensorflow-GPU v1.4.0. Here is part of the error output, which may be informative:

    Caused by op 'shuffle_batch', defined at:
     File "teststuff.py", line 59, in <module>
    sbatch, lbatch = read_and_decode_single_example("stuff.tfrecords" )
     File "teststuff.py", line 54, in read_and_decode_single_example
    min_after_dequeue=batch_size)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch
    name=name)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/input.py", line 796, in _shuffle_batch
    dequeued = queue.dequeue_many(batch_size, name=name)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 464, in dequeue_many
    self._queue_ref, n=n, component_types=self._dtypes, name=name)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 2418, in _queue_dequeue_many_v2
    component_types=component_types, timeout_ms=timeout_ms, name=name)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
     File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
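
One possible cause worth ruling out, not confirmed for this setup: when `num_epochs` is set, `tf.train.string_input_producer` keeps its epoch counter in a local variable, and the reading code starts the queue runners before `tf.local_variables_initializer()` has run. If a runner thread touches the uninitialized counter first, it can fail and close the queues, which would match a queue that is closed with `current size 0` on the very first `sess.run`. The sketch below runs the initializers first and treats `OutOfRangeError` as the normal end of the single epoch. Several things in it are assumptions made for illustration: the import `tensorflow.compat.v1` (available in TensorFlow 1.14+ and 2.x; on 1.4.0 a plain `import tensorflow as tf` without the `disable_eager_execution()` call would be the equivalent), the reduced toy shapes, the temporary file path, and the helper names `write_toy_records` and `read_all_batches`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x

tf.disable_eager_execution()

BATCH_SIZE = 20      # assumed from "requested 20" in the error message
SEQ_SHAPE = [5, 3]   # smaller than the original [500, 300], to keep the sketch fast
LBL_SHAPE = [9, 1]   # smaller than the original [90, 1]


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def write_toy_records(path, n_records=100):
    # Hypothetical helper: writes random float32 arrays as raw bytes.
    writer = tf.python_io.TFRecordWriter(path)
    for _ in range(n_records):
        seq = np.random.uniform(size=SEQ_SHAPE).astype(np.float32)
        lbl = np.random.uniform(size=LBL_SHAPE).astype(np.float32)
        feature = {'train/lbl': _bytes_feature(lbl.tobytes()),
                   'train/seq': _bytes_feature(seq.tobytes())}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    writer.close()


def read_all_batches(path, batch_size):
    # Hypothetical helper: returns the shapes of every batch in one epoch.
    graph = tf.Graph()
    with graph.as_default():
        filename_queue = tf.train.string_input_producer([path], num_epochs=1)
        _, serialized = tf.TFRecordReader().read(filename_queue)
        features = tf.parse_single_example(serialized, features={
            'train/lbl': tf.FixedLenFeature([], tf.string),
            'train/seq': tf.FixedLenFeature([], tf.string)})
        seq = tf.reshape(tf.decode_raw(features['train/seq'], tf.float32), SEQ_SHAPE)
        lbl = tf.reshape(tf.decode_raw(features['train/lbl'], tf.float32), LBL_SHAPE)
        sbatch, lbatch = tf.train.shuffle_batch([seq, lbl],
                                                batch_size=batch_size,
                                                capacity=3 * batch_size,
                                                min_after_dequeue=batch_size)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

    shapes = []
    with tf.Session(graph=graph) as sess:
        # Initialize first: the epoch counter is a local variable.
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                s, l = sess.run([sbatch, lbatch])
                shapes.append((s.shape, l.shape))
        except tf.errors.OutOfRangeError:
            pass  # the single epoch is exhausted; this is the expected end signal
        finally:
            coord.request_stop()
            coord.join(threads)
    return shapes


record_path = os.path.join(tempfile.mkdtemp(), 'stuff.tfrecords')
write_toy_records(record_path)
batch_shapes = read_all_batches(record_path, BATCH_SIZE)
print(len(batch_shapes), batch_shapes[0])
```

With 100 records and a batch size of 20, one epoch would be expected to yield five full batches before `OutOfRangeError` is raised. If the error still appears on the first batch with this ordering, the cause probably lies elsewhere, for example in the file path or in the contents of the file.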
    

0 answers:

No answers