在Tensorflow Estimators中使用TFRecords时,是否有一种简单的方法来设置时期

时间:2019-03-23 10:44:53

标签: tensorflow tfrecord

将numpy数组输入估算器时,有一种设置历元的好方法

  tf.estimator.inputs.numpy_input_fn(
     x,
     y=None,
     batch_size=128,
     num_epochs=1 ,
     shuffle=None,
     queue_capacity=1000,
     num_threads=1  
   )

但是我无法使用TFRecords来追踪类似的方法,大多数人似乎只是将其陷入循环

 i = 0 
 while ( i < 100000):
   model.train(input_fn=input_fn, steps=100)

是否有一种干净的方法来使用估计器显式设置TFRecords的时期数?

1 个答案:

答案 0 :(得分:1)

您可以使用dataset.repeat(num_epochs)设置纪元数。数据集管道输出输入到model.train()

的数据集对象,即批处理大小的元组(功能,标签)。
dataset = tf.data.TFRecordDataset(file.tfrecords)
dataset = tf.shuffle().repeat()
...
dataset = dataset.batch()

为了使其工作,请设置model.train(steps=None, max_steps=None)。在这种情况下,一旦达到num_epoch,就让Dataset API通过生成tf.errors.OutOfRange错误或StopIteration异常来处理时代计数。