How to iterate over a dataset multiple times with the TensorFlow Dataset API

Date: 2017-11-02 04:01:32

Tags: tensorflow tensorflow-datasets

How can I read the values of a dataset more than once? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How can I make this work?

4 answers:

Answer 0 (score: 19)

First, I suggest you read the Data Set Guide, which describes the Dataset API in detail.

Your problem is iterating over the data multiple times. Here are two solutions:

  1. Iterate over all epochs at once, with no information about where each individual epoch ends:

    import tensorflow as tf
    
    epoch   = 10
    dataset = tf.data.Dataset.range(100)
    dataset = dataset.repeat(epoch)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    sess = tf.Session()
    
    num_batch = 0
    j = 0
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
            if j > 99: # new epoch
                j = 0
        except tf.errors.OutOfRangeError:
            break
    
    print ("Num Batch: ", num_batch)
    
  2. The second option tells you when each epoch ends, so you can, for example, check the validation loss:

      import tensorflow as tf
      
      epoch = 10
      dataset = tf.data.Dataset.range(100)
      iterator = dataset.make_initializable_iterator()
      next_element = iterator.get_next()
      sess = tf.Session()
      
      num_batch = 0
      
      for e in range(epoch):
          print ("Epoch: ", e)
          j = 0
          sess.run(iterator.initializer)
          while True:
              try:
                  value = sess.run(next_element)
                  assert j == value
                  j += 1
                  num_batch += 1
              except tf.errors.OutOfRangeError:
                  break
      
      print ("Num Batch: ", num_batch)
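For comparison, here is a minimal sketch of both options under TensorFlow 2.x eager execution (an assumption; the code above targets TF 1.x sessions). In eager mode, each for loop over a tf.data.Dataset creates a fresh iterator, so re-entering the loop plays the role of sess.run(iterator.initializer):

```python
import tensorflow as tf

EPOCHS = 10
dataset = tf.data.Dataset.range(100)

# Option 1: repeat() chains all epochs into one stream; epoch boundaries are not signalled.
option1_count = 0
for value in dataset.repeat(EPOCHS):
    assert int(value) == option1_count % 100  # values cycle through 0..99
    option1_count += 1

# Option 2: an outer Python loop; each inner `for` starts a new pass over the data.
option2_count = 0
for epoch in range(EPOCHS):
    for j, value in enumerate(dataset):
        assert int(value) == j
        option2_count += 1
    # End of epoch: this is where a validation step could go.

print("Num Batch:", option1_count, option2_count)
```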
      

Answer 1 (score: 3)

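If your TensorFlow version is 1.3+, I recommend the high-level API tf.train.MonitoredTrainingSession. A session created by this API handles tf.errors.OutOfRangeError for you when the input runs out, so the while not sess.should_stop() loop below ends cleanly without an explicit try/except. In most training scenarios you also need to shuffle the data and draw batches, so I have added those steps to the code below.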

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
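The order of batch and repeat in this pipeline decides whether a batch can span an epoch boundary. A small sketch, assuming TF 2.x eager mode, that records the batch sizes for both orders (shuffle is left out because it does not change the sizes):

```python
import tensorflow as tf

EPOCHS = 10

# batch() before repeat(): every epoch ends with its own short batch (100 = 3 * 32 + 4).
batch_then_repeat = tf.data.Dataset.range(100).batch(32).repeat(EPOCHS)
sizes_a = [int(b.shape[0]) for b in batch_then_repeat]

# repeat() before batch(): all 1000 elements are batched as one stream (1000 = 31 * 32 + 8).
repeat_then_batch = tf.data.Dataset.range(100).repeat(EPOCHS).batch(32)
sizes_b = [int(b.shape[0]) for b in repeat_then_batch]

print(len(sizes_a), sizes_a[:4])  # 40 [32, 32, 32, 4]
print(len(sizes_b), sizes_b[-1])  # 32 8
```

Batching before repeating keeps each epoch's data separate, at the cost of one smaller batch per epoch.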

Answer 2 (score: 2)

Try this:

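while True:
  try:
    # next_element = iterator.get_next(), as defined in the question
    print(sess.run(next_element))
  except tf.errors.OutOfRangeError:
    # The iterator is exhausted: leave the loop
    break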

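Whenever the dataset iterator reaches the end of the data, it raises tf.errors.OutOfRangeError. You can catch it with except and then start the dataset over from the beginning (with an initializable iterator, by running its initializer again; a one-shot iterator cannot be restarted).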
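A self-contained sketch of that catch-and-restart loop, assuming TF 2.x eager mode, where iter(dataset) creates a new iterator and its get_next() method raises tf.errors.OutOfRangeError at the end of the data. NUM_PASSES is a hypothetical stopping rule added only so the loop terminates:

```python
import tensorflow as tf

NUM_PASSES = 3  # hypothetical stopping rule, not part of the answer above
dataset = tf.data.Dataset.range(5)
iterator = iter(dataset)

values = []
passes = 0
while passes < NUM_PASSES:
    try:
        values.append(int(iterator.get_next()))
    except tf.errors.OutOfRangeError:
        # End of data: count the finished pass and start the dataset over.
        passes += 1
        iterator = iter(dataset)

print(values)
```

Note that iterating with a plain for loop or next() in TF 2.x turns the same condition into StopIteration, which is why the explicit get_next() call is used here.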

Answer 3 (score: 1)

Similar to Tom's answer, for TensorFlow 2+ you can use the following high-level API calls (the code suggested in his answer is deprecated in TensorFlow 2+):

import tensorflow as tf

epoch = 10
batch_size = 32
dataset = tf.data.Dataset.range(100)

dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.repeat(epoch)

num_batch = 0
for batch in dataset:  # eager iteration: no session or explicit iterator needed
    num_batch += 1
    print("Num Batch: ", num_batch)

A useful call for tracking progress returns the total number of batches to iterate over (use it after the batch and repeat calls):

num_batches = tf.data.experimental.cardinality(dataset)

Note that at the time of writing (TensorFlow 2.1), the cardinality method is still experimental.
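A short sketch of what cardinality reports, assuming TF 2.x (shuffle is omitted here for simplicity). Batching 100 elements into batches of 32 gives ceil(100 / 32) = 4 batches per epoch, so 10 epochs give 40; dropping the remainder leaves 3 per epoch; an unbounded repeat() reports the constant tf.data.experimental.INFINITE_CARDINALITY:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(32).repeat(10)
num_batches = int(tf.data.experimental.cardinality(dataset))
print(num_batches)  # 40

# drop_remainder=True discards the short final batch of each epoch: 3 batches per epoch.
dropped = tf.data.Dataset.range(100).batch(32, drop_remainder=True).repeat(10)
num_dropped = int(tf.data.experimental.cardinality(dropped))
print(num_dropped)  # 30

# repeat() with no count produces an infinite dataset.
infinite = tf.data.Dataset.range(100).batch(32).repeat()
is_infinite = int(tf.data.experimental.cardinality(infinite)) == tf.data.experimental.INFINITE_CARDINALITY
print(is_infinite)  # True
```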