如何使用tfslim记录验证丢失和准确性

时间:2018-04-18 17:34:33

标签: tensorflow tensorboard tf-slim

使用tf-slim时,有什么方法可以将验证丢失和准确性记录到tensorboard?当我使用keras时,以下代码可以为我做到这一点:

model.fit_generator(generator=train_gen(), validation_data=valid_gen(),...)

然后模型将评估每个时期后的验证损失和准确性,这非常方便。但是如何使用tf-slim实现这一目标?以下步骤使用原始张量流,这不是我想要的:

with tf.Session() as sess:
    for step in range(100000):
        sess.run(train_op, feed_dict={X: X_train, y: y_train})
        if n % batch_size * batches_per_epoch == 0:
            print(sess.run(train_op, feed_dict={X: X_train, y: y_train}))

目前,使用tf-slim训练模型的步骤是:

tf.contrib.slim.learning.train(
    train_op=train_op,
    logdir="logs",
    number_of_steps=10000,
    log_every_n_steps = 10,
    save_summaries_secs=1
)

那么如何使用上述细长的培训程序评估每个时期后的验证损失和准确性?

提前致谢!

1 个答案:

答案 0 :(得分:1)

此问题仍在讨论TF Slim repo(issue #5987)。 该框架允许您轻松创建评估脚本,以便在培训之后/之后运行(下面的解决方案1),但是有些人正在努力实现“批量培训+验证的经典循环”(解决方案2)。

1。在另一个脚本中使用slim.evaluation

TF Slim有评估方法,例如: slim.evaluation.evaluation_loop()您可以在另一个脚本(可以与您的训练并行运行)中使用,以定期加载模型的最新检查点并执行评估。 TF Slim页面包含一个很好的示例,说明此类脚本的外观:example

2。向train_step_fn

提供自定义slim.learning.train()

讨论的发起人提出的一个不完整的解决方案利用了您可以提供给slim.learning.train()的自定义训练步骤功能:

"""
Snippet from code by Kevin Malakoff @kmalakoff
https://github.com/tensorflow/tensorflow/issues/5987#issue-192626454
"""
# ...
accuracy_validation = slim.metrics.accuracy(
    tf.argmax(predictions_validation, 1), 
    tf.argmax(labels_validation, 1)) # ... or whatever metrics needed

def train_step_fn(session, *args, **kwargs):
  total_loss, should_stop = train_step(session, *args, **kwargs)

  if train_step_fn.step % FLAGS.validation_check == 0:
    accuracy = session.run(train_step_fn.accuracy_validation)
    print('Step %s - Loss: %.2f Accuracy: %.2f%%' % (str(train_step_fn.step).rjust(6, '0'), total_loss, accuracy * 100))

  # ...

  train_step_fn.step += 1
  return [total_loss, should_stop]

train_step_fn.step = 0
train_step_fn.accuracy_validation = accuracy_validation

slim.learning.train(
  train_op,
  FLAGS.logs_dir,
  train_step_fn=train_step_fn,
  graph=graph,
  number_of_steps=FLAGS.max_steps
)