I have a training dataset and a test dataset:
#training dataset
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.map(...)
dataset_train = dataset_train.shuffle(...)
dataset_train = dataset_train.batch(...)
dataset_train = dataset_train.repeat(1)
iterator_train = dataset_train.make_initializable_iterator()
#test dataset
dataset_test = tf.data.TFRecordDataset(files_test)
dataset_test = dataset_test.map(...)
dataset_test = dataset_test.shuffle(...)
dataset_test = dataset_test.batch(...)
dataset_test = dataset_test.repeat(...)
iterator_test = dataset_test.make_initializable_iterator()
#for switch between two datasets.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
image_batch, label_batch = iterator.get_next()
In the session, I have:
# in tf.Session()
train_iterator_handle = sess.run(iterator_train.string_handle())
test_iterator_handle = sess.run(iterator_test.string_handle())
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
# start training, switch to training dataset
sess.run(iterator_train.initializer)
while True:
    try:
        sess.run([train_step, ...], feed_dict={handle: train_iterator_handle})
        if global_step % N == 0:  # test
            # start test, switch to test dataset
            sess.run(iterator_test.initializer)
            while True:
                try:
                    sess.run([acc_update, ...], feed_dict={handle: test_iterator_handle})
                except tf.errors.OutOfRangeError:
                    print("test finished")
                    break
            # test finished, switch back to training dataset
            sess.run(iterator_train.initializer)
    except tf.errors.OutOfRangeError:
        print("training finished")
        break
I read in the TF API docs that the training dataset iterator can continue from where it left off, and I expect the training loop to stop once it has iterated over all the data, since I used:
dataset_train = dataset_train.repeat(1)
But in practice my program keeps running and never stops, so I must have made a serious mistake somewhere. Can anyone help me?
Answer 0 (score: 1)
The line sess.run(iterator_train.initializer) that you call after validation resets the state of the training iterator, so it starts reading from the beginning again. Presumably N is smaller than the number of steps in one pass over the training iterator, so the iterator is reset before it ever runs out, and training never stops.
If you just want to continue training after validation, do not call the training iterator's initializer again.
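The reset behaviour described above can be reproduced with a plain-Python sketch (no TensorFlow; train_data and N here are made-up stand-ins for the question's one-epoch dataset and test interval):

```python
# Pure-Python analogy: re-creating an iterator restarts it from the
# beginning, which is what sess.run(iterator_train.initializer) does
# to the training dataset.
train_data = list(range(10))  # stands in for one epoch of training batches
N = 3                         # validate every N steps

steps = 0
it = iter(train_data)         # "initialize" the training iterator once
while True:
    try:
        next(it)              # one training step
        steps += 1
        if steps % N == 0:
            pass              # run validation here; do NOT re-create `it`.
            # The bug from the question is equivalent to doing
            #     it = iter(train_data)
            # here, which resets the epoch so the loop never ends.
    except StopIteration:
        break                 # one pass (repeat(1)) is exhausted

print("training finished after", steps, "steps")
```

With the reset left out, the loop ends after exactly one pass over the data; uncommenting the reset makes it run forever, just like the question's program.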