Using tf.train.SyncReplicasOptimizer with multiple optimizers

Date: 2018-01-31 11:23:58

Tags: tensorflow deep-learning

I am trying to run DeepLab-ResNet (https://github.com/DrSleep/tensorflow-deeplab-resnet) in a distributed setting. I have chosen a synchronous data-parallel training approach similar to the one demonstrated in the distributed training example for Inception (https://github.com/tensorflow/models/tree/master/research/inception).

In that example, tf.train.SyncReplicasOptimizer is used to aggregate the gradients computed by every worker with a single RMSPropOptimizer. The global_step variable is also updated by its apply_gradients call. The following snippet shows this typical case:

# Optimizer in every worker that performs gradient descent.
opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                momentum=RMSPROP_MOMENTUM,
                                epsilon=RMSPROP_EPSILON)

# A synchronous replica optimizer that wraps the RMSPropOptimizer.
opt = tf.train.SyncReplicasOptimizer(opt,
    replicas_to_aggregate=num_replicas_to_aggregate,
    total_num_replicas=num_workers,
    variable_averages=exp_moving_averager,
    variables_to_average=variables_to_average)

# Apply the aggregated gradients and update global_step.
apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
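
As far as I can tell, in the same example the chief worker additionally retrieves the sync queue runner and the init-tokens op from the wrapped optimizer and runs them once a session exists (paraphrased from the Inception code, so the exact wiring may differ; is_chief, sv and sess are names from that example, not from my code):

# Chief-only bookkeeping created by SyncReplicasOptimizer.
chief_queue_runner = opt.get_chief_queue_runner()
init_tokens_op = opt.get_init_tokens_op()

# On the chief worker, after the tf.train.Supervisor has created the session.
if is_chief:
    sv.start_queue_runners(sess, [chief_queue_runner])
    sess.run(init_tokens_op)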

In my case, three optimizers with different learning rates are used, each responsible for a specific portion of the network.

# Three optimizers declared with different learning rates.
opt_conv = tf.train.MomentumOptimizer(learning_rate, args.momentum)
opt_fc_w = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)
opt_fc_b = tf.train.MomentumOptimizer(learning_rate * 20.0, args.momentum)

# Split the gradient list into the slice that belongs to each optimizer.
grads = tf.gradients(reduced_loss, conv_trainable + fc_w_trainable + fc_b_trainable)
grads_conv = grads[:len(conv_trainable)]
grads_fc_w = grads[len(conv_trainable):(len(conv_trainable) + len(fc_w_trainable))]
grads_fc_b = grads[(len(conv_trainable) + len(fc_w_trainable)):]

# Gradients applied to the corresponding portion of the network.
train_op_conv = opt_conv.apply_gradients(zip(grads_conv, conv_trainable))
train_op_fc_w = opt_fc_w.apply_gradients(zip(grads_fc_w, fc_w_trainable))
train_op_fc_b = opt_fc_b.apply_gradients(zip(grads_fc_b, fc_b_trainable))

# train_op groups all optimizer operations.
train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)

I do not know how to use tf.train.SyncReplicasOptimizer with multiple optimizers. I also do not know how to update the global_step variable and how to use the chief_queue_runner in this setup. Please help.
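
To make this concrete, a naive wiring I can imagine (reusing the names from the snippets above) is to wrap each MomentumOptimizer in its own SyncReplicasOptimizer and let only one apply_gradients call advance global_step, but I have no idea whether this is correct or how the chief queue runner(s) should then be handled:

# Naive sketch, not taken from any working example: wrap each optimizer separately.
opt_conv = tf.train.SyncReplicasOptimizer(opt_conv,
    replicas_to_aggregate=num_replicas_to_aggregate,
    total_num_replicas=num_workers)
opt_fc_w = tf.train.SyncReplicasOptimizer(opt_fc_w,
    replicas_to_aggregate=num_replicas_to_aggregate,
    total_num_replicas=num_workers)
opt_fc_b = tf.train.SyncReplicasOptimizer(opt_fc_b,
    replicas_to_aggregate=num_replicas_to_aggregate,
    total_num_replicas=num_workers)

# Let only one of the three calls advance global_step so that a single
# training step is not counted three times.
train_op_conv = opt_conv.apply_gradients(zip(grads_conv, conv_trainable),
                                         global_step=global_step)
train_op_fc_w = opt_fc_w.apply_gradients(zip(grads_fc_w, fc_w_trainable))
train_op_fc_b = opt_fc_b.apply_gradients(zip(grads_fc_b, fc_b_trainable))
train_op = tf.group(train_op_conv, train_op_fc_w, train_op_fc_b)

# Unclear: does the chief now have to start one queue runner per wrapped
# optimizer, or only one of them?
chief_queue_runners = [opt_conv.get_chief_queue_runner(),
                       opt_fc_w.get_chief_queue_runner(),
                       opt_fc_b.get_chief_queue_runner()]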

0 Answers:

There are no answers yet.