TensorFlow:使用受监控的培训会话

时间:2017-11-09 02:48:08

标签: python tensorflow deep-learning

我正在使用数据集API导入培训和验证数据。我有TF 1.2。因此,我只能使用可重新初始化的迭代器,并且不能使用可馈送迭代器,因为可馈送迭代器只能从TF 1.4中使用。

1)如果我们想培训网络,我们可以简单地使用受监控的培训课程。但是,当我们想要在培训时验证我们应该如何做到这一点?我们应该转储受监控的培训课程并使用低级别会话吗?

train_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "train.tfrecords"])
train_dataset = train_dataset.map(_parse_records)
train_dataset.shuffle(buffer_size=1000)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(FLAGS.batch_size)

validation_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "test.tfrecords"])
validation_dataset = test_dataset.map(_parse_records)
validation_dataset = test_dataset.batch(FLAGS.batch_size)

iterator = tf.contrib.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_init_op = iterator.make_initializer(train_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
    sess.run(train_init_op)
    while not sess.should_stop():
        sess.run(training_op)

   # HOW TO VALIDATE?

2)有没有办法使用Reinitializable迭代器在一个纪元的中间验证模型,因为它需要在我们在迭代器之间切换时从数据集的开头初始化迭代器。可以使用Reinitializable迭代器,或者我们必须切换到可输入迭代器来执行此操作吗?

这是TF数据集教程中提供的示例。如果在一个纪元中有100次迭代,我们可以使用Reinitializable迭代器在迭代50验证模型吗? (我认为可以使用可馈送的迭代器)

# Run 20 epochs in which the training dataset is traversed, followed by the validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
    sess.run(next_element)

3)在使用可重新初始化的迭代器时,在时间的最后一次迭代中,如果剩余的训练数据样本小于所需的批量大小,会发生什么? 剩下的少量样品是否会以减少的批量大小使用,否则会被忽略?

2 个答案:

答案 0 :(得分:2)

对于你的问题3,TensorFlow表现不佳,我想。对于最后一批,它可能具有较少数量的样本。这经常(总是?)在训练期间导致“不兼容的形状”错误。请参阅https://stackoverflow.com/a/48331954/2184122了解如何解决TensorFlow 1.4的问题

答案 1 :(得分:1)

请看一下How to switch between training and validation dataset with tf.MonitoredTrainingSession? 我想你会找到1)和2)的答案。 您可以使用feed_dict更改要评估的数据集,或者只是重新初始化它。从链接:

...
training_iterator = training_ds.make_initializable_iterator()
validation_iterator = validation_ds.make_initializable_iterator()
...
sess.run(next_element, feed_dict={handle: training_handle})
...
sess.run(next_element, feed_dict={handle: validation_iterator })