与使用TF-slim训练相比,无法使用vanilla Tensorflow代码实现相同的性能训练

时间:2018-06-17 23:07:58

标签: python tensorflow resnet tf-slim

以下使用TF-Slim库加载模型并对其进行微调的代码在分类任务中实现了90%的性能(我省略了加载数据和预处理):

with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=0.0001)):
    logits, _ = resnet_v1.resnet_v1_50(images, num_classes=dataset.num_classes, is_training=True)

one_hot_labels = slim.one_hot_encoding(labels, NUM_CLASSES)
tf.losses.softmax_cross_entropy(one_hot_labels, logits)
total_loss = tf.losses.get_total_loss()
global_step = variables.get_or_create_global_step()
lr = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, GAMMA)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)
init_fn = slim.assign_from_checkpoint_fn("resnet_v1_50.ckpt", VARIABLES_TO_RESTORE)

final_loss = slim.learning.train( train_op, logdir=train_dir, log_every_n_steps=500, save_summaries_secs=25,  init_fn=init_fn, number_of_steps = NUM_STEPS)

我尝试使用vanilla tensorflow重写相同的代码以更好地控制训练过程,并且由于某些原因,当使用所有相同的超参数(大写)和相同的预处理时,我无法实现相同的性能(10%性能下降)。差异在图表定义中:

        lr = tf.train.exponential_decay(LEARNING_RATE,  global_step, DECAY_STEPS, GAMMA)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
        full_train_op = optimizer.minimize(total_loss, global_step=global_step)

和培训:

for s in range(NUM_STEPS):
    sess.run(train_init_op) #Initializes dataset iterator
    while True:
        try:
            sess.run([full_train_op], feed_dict={is_training: True})                    
        except tf.errors.OutOfRangeError:
            break

超薄列车功能是否在进行其他操作?我认为它可能是使用批量规范化或我在我的代码版本上没有实现的其他东西。

是否可以在tensorflow中加载slim resnet模型并在没有细长列车功能的情况下训练它?我对覆盖train_step_fn不感兴趣。

1 个答案:

答案 0 :(得分:0)

这可能是因为没有与resnet的批量规范相关联update_ops

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
with tf.control_dependencies(update_ops):
    full_train_op = optimizer.minimize(total_loss, global_step)
# same training loop