How to set up an epoch counter when using TFRecordDataset?

Asked: 2019-01-22 04:16:12

Tags: python tensorflow

I am reading a dataset from a TFRecord file with tf.data.TFRecordDataset.

I am trying to figure out which epoch each training step is processing.

I tried the answer from Epoch counter with TensorFlow Dataset API, but it does not seem to work for me.

Details: the TFRecord file holds 100 samples, batch_size is set to 50, and epoch_num is set to 5.

Here is my simplified code:

import tensorflow as tf  # TF 1.x

def read_and_decode_TFRecordDataset(tfrecords_path, batch_size, epoch_num):
    dataset = tf.data.TFRecordDataset(tfrecords_path)
    dataset = dataset.map(parser_deblur)
    epoch = tf.data.Dataset.range(epoch_num)
    dataset = epoch.flat_map(lambda i: tf.data.Dataset.zip(
        (dataset, tf.data.Dataset.from_tensors(i).repeat())))
    dataset = dataset.repeat(epoch_num).shuffle(1000).batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    (face_blur_batch, face_gt_batch), epochNow = iterator.get_next()
    return face_blur_batch, face_gt_batch, epochNow
# in the training loop:
print(f"EPOCH: {epochNow}, STEP: {step}")

What I expected:

EPOCH: [0 0 0 ... 0 0] (fifty 0s), STEP: 1
EPOCH: [0 0 0 ... 0 0] (fifty 0s), STEP: 2
EPOCH: [1 1 1 ... 1 1] (fifty 1s), STEP: 3
EPOCH: [1 1 1 ... 1 1] (fifty 1s), STEP: 4
EPOCH: [2 2 2 ... 2 2] (fifty 2s), STEP: 5
EPOCH: [2 2 2 ... 2 2] (fifty 2s), STEP: 6
EPOCH: [3 3 3 ... 3 3] (fifty 3s), STEP: 7
EPOCH: [3 3 3 ... 3 3] (fifty 3s), STEP: 8
EPOCH: [4 4 4 ... 4 4] (fifty 4s), STEP: 9
EPOCH: [4 4 4 ... 4 4] (fifty 4s), STEP: 10

But the actual output is:

EPOCH: [2 0 4 ... 4 1] , STEP: 1
EPOCH: [4 0 2 ... 3 4] , STEP: 2 
EPOCH: [4 0 3 ... 2 2] , STEP: 3
EPOCH: [1 1 3 ... 1 3] , STEP: 4
EPOCH: [1 4 0 ... 0 1] , STEP: 5
EPOCH: [0 4 4 ... 4 3] , STEP: 6
EPOCH: [3 1 0 ... 3 2] , STEP: 7
EPOCH: [4 2 4 ... 3 1] , STEP: 8
EPOCH: [0 0 1 ... 3 3] , STEP: 9
EPOCH: [3 1 3 ... 3 2] , STEP: 10

I do not understand what the EPOCH output means; it looks random, and it differs on every run.

Any idea how to fix the code above, or how to get an epoch counter some other way?
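The behavior above can be reproduced without TensorFlow. The following plain-Python sketch (where `samples`, `epoch_num`, and `batch_size` are illustrative stand-ins for the pipeline's values, not part of the original code) mimics what the question's pipeline does: it tags every element with its epoch index first, then shuffles across the whole 5-epoch stream, so a single batch ends up mixing tags from several epochs.

```python
import random

random.seed(0)
samples = list(range(100))      # stands in for the 100 decoded records
epoch_num, batch_size = 5, 50

# tag each sample with its epoch index, concatenating all epochs...
tagged = [(s, e) for e in range(epoch_num) for s in samples]
# ...then shuffle across the WHOLE 500-element stream,
# which is what shuffling AFTER the zip effectively does
random.shuffle(tagged)

first_batch = tagged[:batch_size]
epoch_tags = sorted({e for _, e in first_batch})
print(epoch_tags)  # multiple different epoch indices in one batch
```

This is why the printed EPOCH vectors look random: the shuffle buffer spans epoch boundaries, so elements from different epochs are interleaved before batching.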

1 Answer:

Answer 0 (score: 0)

I have solved it. The problem was that I shuffled after tagging each element with its epoch index; the shuffle has to come first:

def read_and_decode_TFRecordDataset(tfrecords_path, batch_size, epoch_num):
    dataset = tf.data.TFRecordDataset(tfrecords_path)
    dataset = dataset.map(parser_deblur).shuffle(buffer_size=100 * batch_size)  # shuffle BEFORE tagging the epoch index
    epoch = tf.data.Dataset.range(epoch_num)
    dataset = epoch.flat_map(lambda i: tf.data.Dataset.zip(
        (dataset, tf.data.Dataset.from_tensors(i).repeat())))
    dataset = dataset.repeat(epoch_num).batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    (face_blur_batch, face_gt_batch), epochNow = iterator.get_next()
    return face_blur_batch, face_gt_batch, epochNow
# in the training loop:
print(f"EPOCH: {epochNow}, STEP: {step}")
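The corrected ordering can also be checked with a plain-Python sketch (again, `samples`, `epoch_num`, and `batch_size` are illustrative stand-ins, not the original pipeline): shuffling each epoch's samples before attaching the epoch index means every batch drawn from the stream carries a single, consistent tag.

```python
import random

random.seed(0)
samples = list(range(100))      # stands in for the 100 decoded records
epoch_num, batch_size = 5, 50

stream = []
for e in range(epoch_num):
    per_epoch = samples[:]      # shuffle BEFORE tagging with the epoch index
    random.shuffle(per_epoch)
    stream.extend((s, e) for s in per_epoch)

# every batch now comes from exactly one epoch
batches = [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]
for step, batch in enumerate(batches, 1):
    tags = {e for _, e in batch}
    assert len(tags) == 1
    print(f"EPOCH: {tags.pop()}, STEP: {step}")
```

With 100 samples and a batch size of 50, each epoch contributes two homogeneous batches, matching the expected output in the question.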