TensorFlow parameter server and batch normalization problem

Time: 2019-10-30 19:10:40

Tags: tensorflow gradient batch-normalization parameter-server

I have a worker class that I use to distribute computation across several GPUs. Each worker computes gradients for its own copy of the model and then applies those gradients to a central server model. This approach works well in many cases, but it appears to break down when I use batch normalization.

Here is my ParameterWorker logic:

import tensorflow as tf


class ParameterWorker:

    def __init__(self, sess, scope, model, iterations, optimizer):
        self.sess = sess
        self.scope = scope
        self.model = model
        self.iterations = iterations
        self.optimizer = optimizer

        # Trainable variables of this worker's model and of the central server model.
        self.worker_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
        self.server_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="server")

        # Gradients are computed from the worker's loss but applied to the server variables.
        self.worker_gradients = tf.gradients(self.model.loss, self.worker_parameters)
        self.update_server = self.optimizer.apply_gradients(zip(self.worker_gradients, self.server_parameters))

        # Copy the server variables back into the worker.
        self.update_worker = []
        for worker_parameter, server_parameter in zip(self.worker_parameters, self.server_parameters):
            self.update_worker.append(worker_parameter.assign(server_parameter))

        # Run the worker's UPDATE_OPS (e.g. batch-norm moving-average updates) as part of each local step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, self.scope)
        with tf.control_dependencies(update_ops):
            self.minimizer = self.optimizer.minimize(self.model.loss)

    def work(self):
        for i in range(self.iterations):
            feed = self.feed()
            loss, accuracy, _ = self.sess.run([self.model.loss, self.model.accuracy, self.minimizer], feed_dict=feed)
            print("({}) iteration {} loss {:.8f} accuracy {:.8f}".format(self.scope, i, loss, accuracy), flush=True)
            self.sess.run(self.update_server, feed_dict=feed)
            self.sess.run(self.update_worker)
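
For context, here is a minimal sketch of how such a worker/server setup might be wired together; the Model class, the scope names, and the device placement are simplified placeholders rather than my exact code:

# Hypothetical wiring (assumes a Model class that builds loss/accuracy tensors
# under the enclosing variable scope, and a feed() method on the worker).
optimizer = tf.train.AdamOptimizer(1e-3)

with tf.variable_scope("server"):
    server_model = Model()

with tf.Session() as sess:
    workers = []
    for gpu in range(2):
        scope = "worker_{}".format(gpu)
        with tf.device("/gpu:{}".format(gpu)), tf.variable_scope(scope):
            worker_model = Model()
        workers.append(ParameterWorker(sess, scope, worker_model,
                                       iterations=1000, optimizer=optimizer))

    sess.run(tf.global_variables_initializer())
    for worker in workers:
        worker.work()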

In particular, the approach does not seem to work for the gradients that come from TensorFlow's batch normalization layer:

output = tf.layers.batch_normalization(output, training=self.training)
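
From what I understand, tf.layers.batch_normalization creates two kinds of variables: the trainable gamma/beta, which do receive gradients, and the non-trainable moving_mean/moving_variance, which only live in GLOBAL_VARIABLES and are refreshed through the UPDATE_OPS collection. If that is right, the copy loop above only synchronizes gamma and beta. A sketch of how the remaining statistics for a scope could be collected and mirrored (the moving_statistics helper and the "worker_0" scope name are just for illustration):

# Sketch: gather the batch-norm moving statistics for a given variable scope.
# moving_mean / moving_variance are non-trainable, so they never appear in the
# TRAINABLE_VARIABLES loop in __init__; only the ops in UPDATE_OPS refresh them.
def moving_statistics(scope):
    return [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
            if "moving_mean" in v.name or "moving_variance" in v.name]

# One possible way to mirror a worker's statistics into the server model
# (illustrative only, not a confirmed fix):
sync_statistics = [server_var.assign(worker_var)
                   for worker_var, server_var in zip(moving_statistics("worker_0"),
                                                     moving_statistics("server"))]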

Is this the correct way to apply the gradients of the batch normalization parameters to the server, or is there another step I should take? Is there some additional step that has to run under tf.control_dependencies, along the lines of:

server_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "server")
with tf.control_dependencies(server_update_ops):
    # do something here with worker gradients?
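
The pattern the TensorFlow documentation describes for tf.layers.batch_normalization is to add the UPDATE_OPS as a dependency of the training op. Applied to the server-side update, I imagine it would look something like the sketch below, but I am not sure it helps here, since the server's update ops would presumably only do something if the server model itself were run on a batch:

# Sketch of the usual dependency pattern, applied on the server side
# (names mirror the attributes of ParameterWorker; illustration only).
server_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "server")
with tf.control_dependencies(server_update_ops):
    update_server = optimizer.apply_gradients(
        zip(worker_gradients, server_parameters))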

0 Answers:

No answers yet