dataset.repeat() not working in TensorFlow

Time: 2018-07-18 15:49:38

Tags: python tensorflow

This is part of the code:

def train(x):
    prediction = cnn(x)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))
    optimizer = tf.train.AdadeltaOptimizer().minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in xrange(num_epochs):
            epoch_loss = 0
            for _ in xrange(batch_size):
                _, c = sess.run([optimizer, cost])
                epoch_loss += c


            print('Epoch {} completed out of {} - loss {}'.format(epoch + 1, num_epochs, epoch_loss))


n_classes = 17
batch_size = 32
dropout_rate = 0.4
num_epochs = 10

train_set = read_image_dataset_tfrecordfile('train.tfrecord', resize=True)

train_set = train_set.batch(batch_size)
train_set.repeat(num_epochs)
train_iterator = train_set.make_one_shot_iterator()

x, y = train_iterator.get_next()

train(x)

When I run it, only the first epoch completes, and then an OutOfRangeError is thrown. Here is the stack trace:

Epoch 1 completed out of 10 - loss 5.82853866496e+11


Traceback (most recent call last):
  File "/Users/user/PycharmProjects/ProveTF/main.py", line 113, in <module>
    train(x)
  File "/Users/user/PycharmProjects/ProveTF/main.py", line 83, in train
    _, c = sess.run([optimizer, cost])
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1355, in _do_run
    options, run_metadata)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1374, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,100,100,1], [?,17]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Caused by op u'IteratorGetNext', defined at:
  File "/Users/user/PycharmProjects/ProveTF/main.py", line 110, in <module>
    x, y = train_iterator.get_next()
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 330, in get_next
    name=name)), self._output_types,
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 866, in iterator_get_next
    output_shapes=output_shapes, name=name)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
    op_def=op_def)
  File "/Users/user/venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

OutOfRangeError (see above for traceback): End of sequence
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,100,100,1], [?,17]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

I tried moving the repeat() call to other places, and I also tried calling repeat() with no arguments, but it still doesn't work.

Are there any solutions or suggestions?
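The timing of the error (epoch 1 finishes, epoch 2 fails) is consistent with a one-shot iterator that walks the batched dataset only once, which is what happens when the value returned by repeat() is discarded. Note also that the inner loop runs batch_size times per epoch, so the code requests 32 batches per epoch regardless of how many batches the data actually contains. The sketch below does that arithmetic in plain Python. The number of records in train.tfrecord is not given in the question, so the size used here (1100) is a hypothetical value, and both helper functions are made up for illustration.

```python
def batches_available(n_examples, batch_size, repeats=1):
    """Hypothetical helper: how many batches a batch(batch_size).repeat(repeats)
    pipeline yields, keeping the final partial batch (integer ceiling)."""
    return repeats * ((n_examples + batch_size - 1) // batch_size)


def first_failing_epoch(n_examples, batch_size, num_epochs, steps_per_epoch, repeats=1):
    """Hypothetical helper: the 1-based epoch in which get_next() would run out
    of batches, or None if the data lasts for every epoch."""
    available = batches_available(n_examples, batch_size, repeats)
    requested = 0
    for epoch in range(1, num_epochs + 1):
        requested += steps_per_epoch
        if requested > available:
            return epoch
    return None


# Assumed dataset size: 1100 records -> 35 batches of up to 32 examples.
N = 1100
# repeat() result discarded: only 35 batches exist, epoch 2 asks for the 64th.
print(first_failing_epoch(N, 32, 10, steps_per_epoch=32))
# repeat(10) result reassigned: 350 batches cover the 320 requested steps.
print(first_failing_epoch(N, 32, 10, steps_per_epoch=32, repeats=10))
```

This prints 2 and then None. For the loop's "epochs" to line up with full passes over the data, steps_per_epoch would have to equal the number of batches per pass (35 in this hypothetical case) rather than batch_size.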

1 Answer:

Answer 0 (score: 1)

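You need to assign the result of repeat() the same way you did with batch(): train_set = train_set.repeat(). The method returns a new dataset; it does not modify the existing one in place.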
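The general rule behind this answer is that tf.data transformations such as batch() and repeat() return a new Dataset and leave the original unchanged, so a call whose result is thrown away has no effect. In the question's code the fix would be train_set = train_set.repeat(num_epochs) (or train_set = train_set.batch(batch_size).repeat(num_epochs)) before make_one_shot_iterator() is called. The class below is a hypothetical plain-Python toy, not the TensorFlow API; it only mimics that return-a-new-object behaviour so the difference between discarding and reassigning is easy to see.

```python
class ToyDataset:
    """Hypothetical stand-in for an immutable dataset (NOT tf.data.Dataset)."""

    def __init__(self, items):
        self._items = tuple(items)  # stored as a tuple, never mutated

    def batch(self, size):
        # Build and return a NEW dataset of batches; self is left unchanged.
        items = self._items
        return ToyDataset(items[i:i + size] for i in range(0, len(items), size))

    def repeat(self, count):
        # Build and return a NEW dataset repeated `count` times; self is unchanged.
        return ToyDataset(self._items * count)

    def __iter__(self):
        return iter(self._items)


ds = ToyDataset(range(5)).batch(2)
ds.repeat(3)            # result discarded: ds still holds a single pass
print(len(list(ds)))    # 3 batches: (0, 1), (2, 3), (4,)
ds = ds.repeat(3)       # result reassigned, as the answer suggests
print(len(list(ds)))    # 9 batches
```

Unlike the real repeat(), whose count argument is optional (no argument means repeat indefinitely), this toy always requires a count.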