用tf.train.MonitoredTrainingSession编写多个摘要

时间:2018-08-31 12:57:21

标签: python tensorflow

我无法找到使用tf.train.MonitoredTrainingSession()编写总结的最优雅方式。我正在使用迭代器在TFRecordsDataSet上训练模型,并且一切正常,我的代码如下:

 # m is defined, takes an iterator
 with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
    while not sess.should_stop():
        sess.run(m.train_op)

但是我添加了第二个迭代器,它遍历验证集。我不想为此训练模型,我只想检索验证数据的准确性和损失。这是我到目前为止的内容:

it_train_handle = training_dataset.make_one_shot_iterator().string_handle()
it_valid_handle = validation_dataset.make_one_shot_iterator().string_handle()

it_handle = tf.placeholder(tf.string, shape=[])

iterator = tf.data.Iterator.from_string_handle(it_handle,
    training_dataset.output_types, training_dataset.output_shapes)

next_element = iterator.get_next()

# model is defined, takes next_element as param 

with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:

    training_handle = sess.run(it_train_handle.string_handle())
    validation_handle = sess.run(it_valid_handle.string_handle())

    while not sess.should_stop():
        sess.run(m.train_op, feed_dict={it_handle: training_handle})
        sess.run([m.accuracy, m.loss], feed_dict={it_handle: validation_handle})

现在,我可能可以定义两个FileWriter并允许它们写入两个不同的文件,例如in this question,但我很确定有更好的方法(毕竟MonitoredTrainingSession具有默认的文件编写器我从未定义过,使用它的全部目的是它会自动执行此类操作,对吧?)

我知道tf.train.SummarySaverHook是一个问题,我想这是解决方案的一部分,但是我如何告诉会话使用哪个保护程序?

非常感谢您提供任何帮助。

0 个答案:

没有答案