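I am looking at this answer about running evaluation metrics during training:
How to use evaluation_loop with train_loop in tf-slim
Overriding train_step_fn=train_step_fn seems to be the reasonable approach. But I want to run a validation loop rather than an evaluation. My graph looks like this: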
with tf.Graph().as_default():
    train_dataset = slim.dataset.Dataset(data_sources="train_*.tfrecord")
    train_images, _, train_labels = load_batch(train_dataset,
                                               batch_size=mini_batch_size,
                                               is_training=True)
    val_dataset = slim.dataset.Dataset(data_sources="validation_*.tfrecord")
    val_images, _, val_labels = load_batch(val_dataset,
                                           batch_size=mini_batch_size,
                                           is_training=False)
    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0005)):
        net, end_points = vgg.vgg_16(train_images,
                                     num_classes=10,
                                     is_training=is_training)
        predictions = tf.nn.softmax(net)
        labels = train_labels
    ...
    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_path,
        slim.get_variables_to_restore(exclude=['vgg_16/fc8']),
        ignore_missing_vars=True
    )
    final_loss = slim.learning.train(train_op, TRAIN_LOG,
                                     train_step_fn=train_step_fn,
                                     init_fn=init_fn,
                                     global_step=global_step,
                                     number_of_steps=steps,
                                     save_summaries_secs=60,
                                     save_interval_secs=600,
                                     session_config=sess_config,
                                     )
I want to add something like this to run a mini-batch validation loop against the network's current weights:
def validate_on_checkpoint(sess, *args, **kwargs):
    loss, mean, stddev = sess.run([val_loss, val_rms_mean, val_rms_stddev],
                                  feed_dict={images: val_images,
                                             labels: val_labels,
                                             is_training: is_training})
    validation_writer = tf.train.SummaryWriter(LOG_DIR + '/validation')
    validation_writer.add_summary(loss, global_step)
    validation_writer.add_summary(mean, global_step)
    validation_writer.add_summary(stddev, global_step)

def train_step_fn(sess, *args, **kwargs):
    total_loss, should_stop = train_step(sess, *args, **kwargs)
    if train_step_fn.step % FLAGS.validation_every_n_step == 0:
        validate_on_checkpoint(sess, *args, **kwargs)
    train_step_fn.step += 1
    return [total_loss, should_stop]
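But I get the error: Graph is finalized and cannot be modified.
Conceptually, I am not sure how I should add this. The training loop needs gradients, dropout, and weight updates for the network, but the validation loop skips all of that. If I try to modify the graph, I keep getting variations on "Graph is finalized and cannot be modified."; if I use an "if is_training: else:" approach, I get "XXX is not defined".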
Answer 0 (score: 1)
I figured out a way to make this work from some other Stack Overflow answers. Here are the basics:
1) Get the inputs and labels for both the train and validation datasets
x_train, y_train = produce_batch(320)
x_validation, y_validation = produce_batch(320)
2) Use reuse=True to share the model weights between the train and validation loops. Here is one way:
with tf.variable_scope("model") as scope:
    # Make the model, reuse weights for validation batches
    predictions, nodes = regression_model(inputs, is_training=True)
    scope.reuse_variables()
    val_predictions, _ = regression_model(val_inputs, is_training=False)
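To make the weight sharing concrete, here is a minimal sketch assuming TensorFlow 2.x with the tf.compat.v1 graph-mode API. The tiny_model function is a hypothetical stand-in for regression_model, and the placeholder shapes are arbitrary. The model is built twice inside one variable scope, and the graph should end up with a single set of variables.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def tiny_model(inputs):
    """Hypothetical stand-in for regression_model: a single linear layer."""
    w = tf1.get_variable("w", shape=[3, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    return tf.matmul(inputs, w) + b


def shared_variable_names():
    """Build train and validation towers that share one set of weights."""
    graph = tf.Graph()
    with graph.as_default():
        train_inputs = tf1.placeholder(tf.float32, [None, 3])
        val_inputs = tf1.placeholder(tf.float32, [None, 3])
        with tf1.variable_scope("model") as scope:
            train_out = tiny_model(train_inputs)
            # From here on, get_variable() looks up existing variables
            # instead of creating new ones.
            scope.reuse_variables()
            val_out = tiny_model(val_inputs)
        # Only one "w" and one "b" exist, even though the model was built twice.
        return sorted(v.name for v in tf1.global_variables())


if __name__ == "__main__":
    print(shared_variable_names())
```

Both towers have to be built before slim.learning.train() starts, because the graph is finalized at that point and no new ops can be added afterwards.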
3) Define your losses, and put the validation loss in a different collection so that it is not added to the train losses in tf.losses.get_losses()
loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)
total_loss = tf.losses.get_total_loss()
val_loss = tf.losses.mean_squared_error(labels=val_targets, predictions=val_predictions,
                                        loss_collection="validation"
                                        )
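The effect of loss_collection can be checked directly. The sketch below assumes TensorFlow 2.x with the tf.compat.v1 API and uses made-up constant tensors in place of real targets and predictions. It counts how many losses land in the default collection versus the "validation" collection.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def count_losses_per_collection():
    """Return (number of default losses, number of "validation" losses)."""
    graph = tf.Graph()
    with graph.as_default():
        # Made-up data, for illustration only.
        targets = tf.constant([[1.0], [2.0]])
        predictions = tf.constant([[1.5], [2.5]])
        val_targets = tf.constant([[3.0], [4.0]])
        val_predictions = tf.constant([[3.0], [5.0]])

        # Goes into the default losses collection.
        tf1.losses.mean_squared_error(labels=targets, predictions=predictions)
        # Goes into the custom "validation" collection instead.
        tf1.losses.mean_squared_error(labels=val_targets,
                                      predictions=val_predictions,
                                      loss_collection="validation")

        default_losses = tf1.losses.get_losses()
        validation_losses = tf1.losses.get_losses(loss_collection="validation")
        return len(default_losses), len(validation_losses)


if __name__ == "__main__":
    print(count_losses_per_collection())
```

Because get_total_loss() only sums the default collection (plus regularization losses), the validation loss stays out of the value that the train op minimizes.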
4) Define train_step_fn() as needed to trigger the validation loop
VALIDATION_INTERVAL = 1000  # validate every 1000 steps

# slim.learning.train(train_step_fn=train_step_fn)
def train_step_fn(sess, train_op, global_step, train_step_kwargs):
    """
    slim.learning.train_step():
      train_step_kwargs = {summary_writer:, should_log:, should_stop:}
    """
    train_step_fn.step += 1  # or use global_step.eval(session=sess)

    # calc training losses
    total_loss, should_stop = slim.learning.train_step(sess, train_op, global_step, train_step_kwargs)

    # validate on interval
    if train_step_fn.step % VALIDATION_INTERVAL == 0:
        validate_loss, validation_delta = sess.run([val_loss, summary_validation_delta])
        print(">> global step {}: train={} validation={} delta={}".format(
            train_step_fn.step, total_loss, validate_loss, validate_loss - total_loss))

    return [total_loss, should_stop]

train_step_fn.step = 0
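The step counter stored as a function attribute can be tried out without TensorFlow. This is a minimal sketch with hypothetical stubs: fake_train_step stands in for slim.learning.train_step, and the validate callback stands in for the sess.run() call on the validation ops. It only illustrates when validation fires.

```python
def make_train_step_fn(base_train_step, validate, interval):
    """Wrap a train step so that validate() runs every `interval` steps."""
    def train_step_fn(sess, train_op, global_step, train_step_kwargs):
        train_step_fn.step += 1
        total_loss, should_stop = base_train_step(
            sess, train_op, global_step, train_step_kwargs)
        if train_step_fn.step % interval == 0:
            validate(sess, train_step_fn.step)
        return [total_loss, should_stop]

    # The counter lives on the function object, as in the answer above.
    train_step_fn.step = 0
    return train_step_fn


def fake_train_step(sess, train_op, global_step, train_step_kwargs):
    """Hypothetical stub: returns a constant loss and never asks to stop."""
    return 0.5, False


def run_demo(num_steps=7, interval=3):
    """Run the wrapped step and return the steps at which validation ran."""
    validated_at = []
    step_fn = make_train_step_fn(
        fake_train_step,
        lambda sess, step: validated_at.append(step),
        interval)
    for _ in range(num_steps):
        step_fn(None, None, None, {})
    return validated_at


if __name__ == "__main__":
    print(run_demo())  # validation runs at steps 3 and 6
```

Building the function through a factory keeps the counter and the interval together, so a second training run can start with a fresh counter.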
5) Add train_step_fn() to the training loop
# Run the training inside a session.
final_loss = slim.learning.train(
    train_op,
    train_step_fn=train_step_fn,
    ...
)
Full results are available at the link in the original answer.