What is a clean way to get the validation loss with tf.train.MonitoredTrainingSession?

Time: 2017-06-24 23:48:06

Tags: tensorflow distributed

I am building a distributed TensorFlow model, and I am somewhat confused about how to use tf.train.MonitoredTrainingSession cleanly.

Here is my training code:

import tensorflow as tf

# Stop once the global step reaches FLAGS.nb_train_step
hooks = [tf.train.StopAtStepHook(last_step=FLAGS.nb_train_step)]

with tf.train.MonitoredTrainingSession(master=target,
                                       is_chief=(FLAGS.task_index == 0),
                                       checkpoint_dir=FLAGS.logs_dir,
                                       hooks=hooks) as sess:

    while not sess.should_stop():
        batch_train = next(gen_train)  # training data generator

        feed_dict = {X: batch_train[0],
                     Y: batch_train[1]}

        variables = [loss, merged_summary, train_step]
        current_loss, summary, _ = sess.run(variables, feed_dict)
        print("Batch loss: %s" % current_loss)
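The loop above only relies on gen_train yielding (inputs, labels) pairs; the question does not show how gen_train and gen_val are built. A minimal sketch of such a generator, assuming the data fits in memory as indexable sequences (Python lists or NumPy arrays). batch_generator is a hypothetical helper, not part of TensorFlow:

```python
import random


def batch_generator(inputs, labels, batch_size, shuffle=True, seed=None):
    """Yield (inputs_batch, labels_batch) tuples forever, reshuffling each epoch.

    Hypothetical helper: the question does not show how gen_train/gen_val are built.
    An incomplete final batch is dropped so every batch has exactly batch_size items.
    """
    if len(inputs) != len(labels):
        raise ValueError("inputs and labels must have the same length")
    if batch_size <= 0 or batch_size > len(inputs):
        raise ValueError("batch_size must be between 1 and len(inputs)")
    rng = random.Random(seed)
    indices = list(range(len(inputs)))
    while True:  # one iteration of this loop is one epoch
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices) - batch_size + 1, batch_size):
            batch_idx = indices[start:start + batch_size]
            yield ([inputs[i] for i in batch_idx],
                   [labels[i] for i in batch_idx])
```

Usage would look like gen_train = batch_generator(x_train, y_train, batch_size=32) followed by batch_train = next(gen_train), where x_train and y_train are placeholders for your own data. The built-in next() works on both Python 2 and 3, whereas the .next() method exists only on Python 2 generators.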

Now, if I want to compute the model's validation loss every n training steps, I can add the following block and run it every n steps:

batch_val = next(gen_val)  # validation data generator
feed_dict = {X: batch_val[0],
             Y: batch_val[1]}

val_loss = sess.run(loss, feed_dict)
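Two small pieces of this validation logic can be written independently of TensorFlow: deciding when to validate from the global step (a local loop counter can differ from the global step on each worker in a distributed job), and averaging the loss over several validation batches instead of a single one. A sketch under those assumptions; should_validate and mean_validation_loss are hypothetical helpers, and eval_loss stands in for a callable that wraps sess.run(loss, feed_dict):

```python
def should_validate(global_step, every_n_steps):
    """Return True on global steps that are positive multiples of every_n_steps."""
    return every_n_steps > 0 and global_step > 0 and global_step % every_n_steps == 0


def mean_validation_loss(eval_loss, val_batches, num_batches):
    """Average eval_loss(batch) over num_batches batches drawn from val_batches.

    eval_loss: any callable returning a scalar loss for one (inputs, labels) batch,
        e.g. a hypothetical lambda b: sess.run(loss, {X: b[0], Y: b[1]}).
    val_batches: an iterator of (inputs, labels) batches, such as gen_val.
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be positive")
    total = 0.0
    for _ in range(num_batches):
        total += float(eval_loss(next(val_batches)))
    return total / num_batches
```

In the question's loop, the step could be obtained by fetching a global step tensor (for example from tf.train.get_or_create_global_step()) together with train_step, and the check placed right after the training sess.run call. These helpers only organize when and how the validation loss is computed; they do not by themselves change how the hooks see the extra sess.run call.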

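But this increases the step count in my hooks, which means the validation loss computation is counted as a training step. Is there a clean way to do this? Or am I misunderstanding what hooks do?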

0 answers