将numpy数组输入估算器时,有一种设置历元的好方法
tf.estimator.inputs.numpy_input_fn(
x,
y=None,
batch_size=128,
num_epochs=1 ,
shuffle=None,
queue_capacity=1000,
num_threads=1
)
但是我无法使用TFRecords来追踪类似的方法,大多数人似乎只是将其陷入循环
i = 0
while ( i < 100000):
model.train(input_fn=input_fn, steps=100)
是否有一种干净的方法来使用估计器显式设置TFRecords的时期数?
答案 0 :(得分:1)
您可以使用dataset.repeat(num_epochs)
设置纪元数。数据集管道输出输入到model.train()
dataset = tf.data.TFRecordDataset(file.tfrecords)
dataset = tf.shuffle().repeat()
...
dataset = dataset.batch()
为了使其工作,请设置model.train(steps=None, max_steps=None)
。在这种情况下,一旦达到num_epoch,就让Dataset API通过生成tf.errors.OutOfRange
错误或StopIteration
异常来处理时代计数。