Distributed training in TensorFlow 1.4: training does not work when the final operation (op) is not an optimizer. How do I run it?

Asked: 2020-08-05 13:30:38

Tags: python tensorflow distributed

I am preparing to train a recommendation model on a TensorFlow cluster. The model's latest update rule is not gradient-based, so I compute the new parameter values manually and then write them back with tf.assign. The code is as follows:

def single_meta_task(self, indata, global_step, scope_name):
    meta_train_predict_1 = []
    meta_train_loss_1 = []
    task_update_ops_1 = []
    meta_train_target_1 = []
    # task_1: first inner update step
    input = indata[0]
    meta_train_predict_1.append(self.basemodel.forward(input[0][0]))
    # snapshot the parameters before any inner updates run
    self.weights = self.basemodel.get_para(scope_name)
    with tf.control_dependencies(list(self.weights.values())):
        meta_train_target_1.append(input[0][1])
        meta_train_loss_1.append(tf.reduce_mean(
            self.basemodel.loss_fn(meta_train_target_1[0], meta_train_predict_1[0])))
        task_update_ops_1.append(self.task_optimizer.minimize(
            loss=meta_train_loss_1[0], global_step=global_step))
    # if more than one inner update step is required:
    for i in range(1, self._num_updates):
        with tf.control_dependencies([task_update_ops_1[i - 1]]):
            meta_train_predict_1.append(self.basemodel.forward(input[i][0]))
            meta_train_target_1.append(input[i][1])
            meta_train_loss_1.append(tf.reduce_mean(
                self.basemodel.loss_fn(meta_train_target_1[i], meta_train_predict_1[i])))
            task_update_ops_1.append(self.task_optimizer.minimize(
                loss=meta_train_loss_1[i], global_step=global_step))
    with tf.control_dependencies([task_update_ops_1[-1]]):
        self.update_weights_1 = self.basemodel.get_para(scope_name)
        self.new_weights_1 = self.get_new_weights(self.weights, scope_name, self.lr_maml)  # the manual parameter-update function
        # write the manually computed weights back into the model variables
        weight_update_op = self.basemodel.set_para(self.new_weights_1, scope_name)
    return weight_update_op
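
Here get_para, set_para, and get_new_weights are my own helpers. set_para essentially maps tf.assign over the scope's variables and groups the assigns into a single op, along these lines (a simplified sketch, not my exact implementation):

def set_para(self, new_weights, scope_name):
    # new_weights: dict mapping variable name -> tensor holding the new value
    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name)
    assign_ops = [tf.assign(var, new_weights[var.op.name]) for var in variables]
    # group every assign into the one op handed back to the caller
    return tf.group(*assign_ops)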

I run weight_update_op in a tf.train.MonitoredTrainingSession on the TensorFlow cluster, but I get:

step: 38 meta_train_loss: 0.6931472 auc: 0.0 auc total: 0.5 time: 57.1792571545
step: 39 meta_train_loss: 0.6931472 auc: 0.5 auc total: 0.5 time: 18.3862159252
step: 40 meta_train_loss: 0.6931472 auc: 0.5 auc total: 0.5 time: 11.3142118454
step: 41 meta_train_loss: 0.6931472 auc: 0.5 auc total: 0.5 time: 12.2699699402

which indicates that the parameters are never updated and the graph is not actually executing.
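
For context, the distributed run follows the standard between-graph replication pattern, roughly like this (a minimal sketch; the hosts, job_name, task_index, model, indata, and scope_name are placeholders for my actual setup):

cluster = tf.train.ClusterSpec({
    "ps": ["ps0:2222"],  # placeholder hosts
    "worker": ["worker0:2222", "worker1:2222"],
})
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

if job_name == "ps":
    server.join()
else:
    # pin variables to the ps tasks, ops to this worker
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % task_index,
            cluster=cluster)):
        global_step = tf.train.get_or_create_global_step()
        weight_update_op = model.single_meta_task(indata, global_step, scope_name)

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0)) as sess:
        while not sess.should_stop():
            # running the final op should pull the whole chain of
            # control dependencies, ending in the tf.assign updates
            sess.run(weight_update_op)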

When I switch from distributed training to local single-machine training with a plain tf.Session, the same code works fine.
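
Locally the working loop is essentially just this (again a sketch; graph construction and num_steps are elided):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
        # here the same op updates the parameters as expected
        sess.run(weight_update_op)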

Can anyone tell me why, and how to run distributed training when the final op is not an Optimizer?

0 Answers