我无法找到使用tf.train.MonitoredTrainingSession()编写总结的最优雅方式。我正在使用迭代器在TFRecordsDataSet上训练模型,并且一切正常,我的代码如下:
# m is defined, takes an iterator
with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
while not sess.should_stop():
sess.run(m.train_op)
但是我添加了第二个迭代器,它遍历验证集。我不想为此训练模型,我只想检索验证数据的准确性和损失。这是我到目前为止的内容:
it_train_handle = training_dataset.make_one_shot_iterator().string_handle()
it_valid_handle = validation_dataset.make_one_shot_iterator().string_handle()
it_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(it_handle,
training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# model is defined, takes next_element as param
with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
training_handle = sess.run(it_train_handle.string_handle())
validation_handle = sess.run(it_valid_handle.string_handle())
while not sess.should_stop():
sess.run(m.train_op, feed_dict={it_handle: training_handle})
sess.run([m.accuracy, m.loss], feed_dict={it_handle: validation_handle})
现在,我可能可以定义两个FileWriter并允许它们写入两个不同的文件,例如in this question,但我很确定有更好的方法(毕竟MonitoredTrainingSession具有默认的文件编写器我从未定义过,使用它的全部目的是它会自动执行此类操作,对吧?)
我知道tf.train.SummarySaverHook是一个问题,我想这是解决方案的一部分,但是我如何告诉会话使用哪个保护程序?
非常感谢您提供任何帮助。