使用TFrecords时如何设置纪元?

时间:2018-12-06 11:39:29

标签: tensorflow

我正在使用TFrecords读取数据集,代码如下所示:

filename_queue = tf.train.string_input_producer(tfrecords_path)
    reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'face_gt': tf.FixedLenFeature([], tf.string),
        'mfcc_gt': tf.FixedLenFeature([], tf.string),
        'identity5': tf.FixedLenFeature([], tf.string)
    })


......


face_gt_batch, mfcc_gt_batch, identity5_batch = tf.train.batch(
        [face_gt, mfcc_gt,  identity5], batch_size=batch_size, num_threads=64, capacity=2000)

训练时,我有一个step_nums指示要训练的步骤。

现在,我想设置一个epoch_nums来指示要训练多少个时期。

但是我不知道TFrecords中到底有多少条记录。因此无法使用epoch_nums=data_size/batch_size

我该怎么做?在哪里设置epoch_nums

这是我的训练代码:

for step in np.arange(step_nums):
    if coord.should_stop():
        break

    ... sess.run() ...

    if step % 10000 == 0 or (step+1) == step_nums:
        saver.save(...)

谢谢!

0 个答案:

没有答案