I'm building a distributed TensorFlow model, and I'm a bit confused about how to use tf.train.MonitoredTrainingSession in a clean way.
Here is my training code:
# Define the number of training steps
hooks = [tf.train.StopAtStepHook(last_step=FLAGS.nb_train_step)]
with tf.train.MonitoredTrainingSession(master=target,
                                       is_chief=(FLAGS.task_index == 0),
                                       checkpoint_dir=FLAGS.logs_dir,
                                       hooks=hooks) as sess:
    while not sess.should_stop():
        batch_train = next(gen_train)  # training data generator
        feed_dict = {X: batch_train[0],
                     Y: batch_train[1]}
        variables = [loss, merged_summary, train_step]
        current_loss, summary, _ = sess.run(variables, feed_dict)
        print("Batch loss: %s" % current_loss)
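Now, if I want to compute the validation loss every n training steps, I can add the following block inside the loop so that it runs every n steps: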
batch_val = next(gen_val)  # validation data generator
feed_dict = {X: batch_val[0],
             Y: batch_val[1]}
val_loss = sess.run([loss], feed_dict)
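But this increases the number of steps seen by my hooks, which means the validation loss computation is counted as a training step. Is there a clean way to do this? Am I misunderstanding the role of hooks?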
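To make the goal concrete, this is roughly the control flow I am after, sketched in plain Python rather than TensorFlow. The names run_training, fake_train_step and fake_validate are hypothetical stand-ins for the real sess.run calls, not part of any library. The point is only that the step counter should advance on training updates and stay put during validation passes.

```python
def run_training(nb_train_step, validate_every, train_one_step, validate):
    """Run nb_train_step training updates, validating every validate_every updates.

    Only training updates advance the step counter; validation passes do not.
    """
    step = 0
    val_history = []
    while step < nb_train_step:
        train_one_step()          # stand-in for sess.run([..., train_step], ...)
        step += 1
        if step % validate_every == 0:
            # stand-in for sess.run([loss], validation_feed_dict)
            val_history.append((step, validate()))
    return step, val_history


# Hypothetical stand-ins that just count how often they are called.
calls = {"train": 0, "val": 0}


def fake_train_step():
    calls["train"] += 1
    return 1.0 / calls["train"]


def fake_validate():
    calls["val"] += 1
    return 0.5


final_step, history = run_training(10, 5, fake_train_step, fake_validate)
print(final_step, history, calls)
```

In this toy version, validation is triggered by the training-step counter but never changes it. What I would like is the equivalent behaviour with the hooks passed to the monitored session.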