我正在使用TFrecords读取数据集,代码如下所示:
filename_queue = tf.train.string_input_producer(tfrecords_path)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
'face_gt': tf.FixedLenFeature([], tf.string),
'mfcc_gt': tf.FixedLenFeature([], tf.string),
'identity5': tf.FixedLenFeature([], tf.string)
})
......
face_gt_batch, mfcc_gt_batch, identity5_batch = tf.train.batch(
[face_gt, mfcc_gt, identity5], batch_size=batch_size, num_threads=64, capacity=2000)
训练时,我有一个step_nums
指示要训练的步骤。
现在,我想设置一个epoch_nums
来指示要训练多少个时期。
但是我不知道TFrecords中到底有多少条记录。因此无法使用epoch_nums=data_size/batch_size
。
我该怎么做?在哪里设置epoch_nums
?
这是我的训练代码:
for step in np.arange(step_nums):
if coord.should_stop():
break
... sess.run() ...
if step % 10000 == 0 or (step+1) == step_nums:
saver.save(...)
谢谢!