在训练期间使用td.cond会导致吞吐量降低

时间:2019-05-14 02:07:52

标签: python tensorflow throughput

在使用resnet50进行图像网络训练的过程中,我们使用LARS更新学习率并在训练的每个步骤中计算LR。训练的吞吐量约为5500。为此,我们打算每隔几步就优化和计算LR操作以提高吞吐量。在原始代码中,我们每步执行compute_lr计算。

我修改了代码,如下所示:

  • Global_step是一个张量,用于观察训练的哪一步;
  • 2是一个常数,表示lr每两步计算一次。

代码:

def compute_lr()
    coumpte_lr 
       ...
    stored_lr
       ...
    return lr
def get_larsvalue()
    get_stored_lr
       ...
    return lr

tf.cond(tf.cast(tf.equal(tf.mod(gg,2),0),tf.bool),lambda:self.compute_lr(),lambda: self.get_larsvalue())

但是在修改代码后,吞吐量下降了。经过分析,我认为这是因为tf.cond不是惰性操作,它将执行两个分支,这显然不是我想要的。我现在不知道该如何编码才能完成我的想法,请大家帮助。

0 个答案:

没有答案