Why doesn't my dataset stop even though I set dataset.repeat(1)?

Time: 2018-08-17 08:13:26

Tags: python tensorflow iterator dataset

I have a training dataset and a test dataset:

#training dataset
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.map(...)
dataset_train = dataset_train.shuffle(...)
dataset_train = dataset_train.batch(...)
dataset_train = dataset_train.repeat(1)
iterator_train = dataset_train.make_initializable_iterator()

#test dataset
dataset_test = tf.data.TFRecordDataset(files_test)
dataset_test = dataset_test.map(...)
dataset_test = dataset_test.shuffle(...)
dataset_test = dataset_test.batch(...)
dataset_test = dataset_test.repeat(...)
iterator_test = dataset_test.make_initializable_iterator()

#string handle used to switch between the two datasets
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
image_batch, label_batch = iterator.get_next()

In the session, I have:

# in tf.Session()
train_iterator_handle = sess.run(iterator_train.string_handle())
val_iterator_handle = sess.run(iterator_test.string_handle())
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

#start training, switch to training dataset
sess.run(iterator_train.initializer) 
while True:
    try:
        sess.run([train_step, ...])

        if global_step % N == 0: # test
            #start test, switch to test dataset
            sess.run(iterator_test.initializer)
            while True:
                try:
                    sess.run([acc_update, ...])
                except tf.errors.OutOfRangeError:
                    print("test finished")
                    break
            #test finished, switch back to training dataset
            sess.run(iterator_train.initializer) 
    except tf.errors.OutOfRangeError:
        print("training finished")
        break

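I read in the TF API docs that the training dataset iterator can resume from where it left off, and I assumed the training dataset would stop once all of its data had been iterated over, because I used: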

dataset_train = dataset_train.repeat(1)

In practice, however, my program keeps running and never stops. So I think I must have made a serious mistake somewhere. Can anyone help me?

1 Answer:

Answer 0: (score: 1)

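The line sess.run(iterator_train.initializer) that runs after validation resets the state of the training iterator, so it starts reading from the beginning again. I guess N is smaller than the number of steps in one pass over the training iterator, so the iterator never reaches its end and the loop never stops.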

If you just want to continue training after validation, don't call the training iterator's initializer again.
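
To see the effect without running TensorFlow, here is a rough plain-Python analogy of the loop in the question. It is only a sketch: run_training, num_batches, eval_every, reinit_after_eval and max_steps are made-up names for illustration, a Python iterator stands in for iterator_train, StopIteration stands in for tf.errors.OutOfRangeError, and max_steps is a safety cap added so the non-terminating case can be observed.

```python
def run_training(num_batches, eval_every, reinit_after_eval, max_steps=1000):
    """Simulate the training loop from the question.

    Returns the number of training steps taken before the data ran out,
    or None if the safety cap max_steps was reached first.
    """
    # Stands in for sess.run(iterator_train.initializer) before the loop.
    train_iter = iter(range(num_batches))
    global_step = 0
    while global_step < max_steps:
        try:
            next(train_iter)  # stands in for sess.run([train_step, ...])
        except StopIteration:  # stands in for tf.errors.OutOfRangeError
            return global_step
        global_step += 1
        if global_step % eval_every == 0:
            # The evaluation pass over the test iterator would run here.
            if reinit_after_eval:
                # Re-initializing rewinds the training data to the start.
                train_iter = iter(range(num_batches))
    return None


print(run_training(10, 3, reinit_after_eval=True))   # None: cap reached, data never runs out
print(run_training(10, 3, reinit_after_eval=False))  # 10: stops after one pass
```

With 10 batches and an evaluation every 3 steps, the re-initializing version never gets past the third batch, which matches the behaviour in the question. Dropping the re-initialization lets the iterator reach its end after one pass.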