读取TFRecordDataset时出现OutOfRangeError

时间:2019-02-23 02:15:25

标签: python tensorflow

我正在使用

tf.train.Feature(float_list=tf.train.FloatList(value=v))

为我的 TFRecordDataset 创建 tf.train.Example 示例。生成数据集似乎没有问题,但每当我尝试用它进行训练时,都会得到一个 OutOfRangeError。

这是我创建数据集的方式:

for agent in agents_provider.agents:

    input_dim = 12 * 8 * 8
    output_dim = agent.encoder_decoder.encoding_size

    def map_example(proto):
        """Parse one serialized tf.train.Example into an (inputs, targets) pair.

        ``inputs`` is parsed as a float vector of length ``input_dim`` and
        ``targets`` as a float vector of length ``output_dim``. The record
        must have been written with exactly that many values in each
        feature's FloatList.
        """
        keys_to_features = {
            'inputs': tf.FixedLenFeature(shape=(input_dim, ), dtype=tf.float32),
            'targets': tf.FixedLenFeature(shape=(output_dim, ), dtype=tf.float32)
        }
        parsed = tf.parse_single_example(proto, keys_to_features)
        return parsed['inputs'], parsed['targets']

    # glob.glob() returns an empty list when nothing matches. An empty file
    # list yields an empty dataset, and the first get_next() then fails with
    # the opaque "OutOfRangeError: End of sequence". Fail fast with a message
    # that names the pattern instead. Sorting makes the shard order
    # deterministic, because glob returns files in arbitrary order.
    shard_pattern = os.path.join(shards_directory, '*.train.shard')
    shard_files = sorted(glob.glob(shard_pattern))
    if not shard_files:
        raise FileNotFoundError(
            'No training shards match {!r}'.format(shard_pattern))

    # Note: despite its name, `dataset` ends up holding a one-shot iterator.
    dataset = (
        tf.data.TFRecordDataset(shard_files)
            .map(map_example, num_parallel_calls=cpu_count())
            # Insert .repeat() here to cycle through the shards for several epochs.
            .batch(1)
            .make_one_shot_iterator()
    )

但运行它:

# `dataset` holds the one-shot iterator; this op yields its next batch.
nxt = dataset.get_next()

session = tf.Session()
init_op = tf.global_variables_initializer()
session.run(init_op)
# Evaluating the get_next op is where an empty input raises OutOfRangeError.
result = session.run(nxt)

会抛出以下错误:

Caused by op 'IteratorGetNext', defined at:
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1664, in <module>
    main()
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1658, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1068, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/media/Data/workspaces/git/chessai/chessai/selfplay.py", line 195, in <module>
    main()
  File "/media/Data/workspaces/git/chessai/chessai/selfplay.py", line 171, in main
    inputs, targets = dataset.get_next()
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 421, in get_next
    name=name)), self._output_types,
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2069, in iterator_get_next
    output_shapes=output_shapes, name=name)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

OutOfRangeError (see above for traceback): End of sequence
     [[node IteratorGetNext (defined at /media/Data/workspaces/git/chessai/chessai/selfplay.py:171)  = IteratorGetNext[output_shapes=[[?,768], [?,16]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

我实在不明白问题出在哪里。


将 map_example() 更改为

def map_example(proto):
    """Parse one serialized tf.train.Example into (inputs, targets) scalars.

    Both features are declared with shape ``[]``, i.e. one float per example.
    After ``.batch()`` they therefore come out with shape ``[?]``.

    NOTE(review): a scalar spec presumably only matches records whose
    FloatList holds exactly one value; vectors written with
    ``FloatList(value=v)`` would need shape ``(len(v),)`` -- confirm
    against the writer.
    """
    feature_spec = {
        name: tf.FixedLenFeature([], dtype=tf.float32)
        for name in ('inputs', 'targets')
    }
    example = tf.parse_single_example(proto, feature_spec)
    return example['inputs'], example['targets']

将给出:

OutOfRangeError (see above for traceback): End of sequence
     [[node IteratorGetNext (defined at /media/Data/workspaces/git/chessai/chessai/selfplay.py:170)  = IteratorGetNext[output_shapes=[[?], [?]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

1 个答案:

答案 0 :(得分:0)

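事实证明,问题在于传给 TFRecordDataset 的文件列表是空的(glob.glob() 没有匹配到任何分片文件)。用空文件列表创建的数据集不包含任何元素,因此第一次调用 get_next() 就会抛出 OutOfRangeError。例如下面这段代码: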

import tensorflow as tf


def main():
    """Reproduce "OutOfRangeError: End of sequence" with an empty dataset.

    A TFRecordDataset built from an empty list of filenames has no elements.
    The very first evaluation of its ``get_next()`` op therefore raises
    ``OutOfRangeError`` instead of returning a record.
    """
    no_files = []
    next_record = (
        tf.data.TFRecordDataset(no_files)
        .make_one_shot_iterator()
        .get_next()
    )

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # This run raises OutOfRangeError, so the print is never reached.
        record = sess.run(next_record)
        print(record)


if __name__ == '__main__':
    main()

会得到同样的 OutOfRangeError:

Caused by op 'IteratorGetNext', defined at:
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1664, in <module>
    main()
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1658, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/pydevd.py", line 1068, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/opt/pycharm/pycharm-2018.1.3/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/media/Data/workspaces/git/chessai/chessai/tmp.py", line 19, in <module>
    main()
  File "/media/Data/workspaces/git/chessai/chessai/tmp.py", line 10, in main
    get_next = iterator.get_next()
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 421, in get_next
    name=name)), self._output_types,
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2069, in iterator_get_next
    output_shapes=output_shapes, name=name)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/home/stefan/miniconda3/envs/chessai/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

OutOfRangeError (see above for traceback): End of sequence
     [[node IteratorGetNext (defined at /media/Data/workspaces/git/chessai/chessai/tmp.py:10)  = IteratorGetNext[output_shapes=[[]], output_types=[DT_STRING], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]