在使用resnet50进行图像网络训练的过程中,我们使用LARS更新学习率并在训练的每个步骤中计算LR。训练的吞吐量约为5500。为此,我们打算每隔几步就优化和计算LR操作以提高吞吐量。在原始代码中,我们每步执行compute_lr
计算。
我修改了代码,如下所示:
Global_step
是一个张量,用于观察训练的哪一步; 2
是一个常数,表示lr
每两步计算一次。代码:
def compute_lr()
coumpte_lr
...
stored_lr
...
return lr
def get_larsvalue()
get_stored_lr
...
return lr
tf.cond(tf.cast(tf.equal(tf.mod(gg,2),0),tf.bool),lambda:self.compute_lr(),lambda: self.get_larsvalue())
但是在修改代码后,吞吐量下降了。经过分析,我认为这是因为tf.cond
不是惰性操作,它将执行两个分支,这显然不是我想要的。我现在不知道该如何编码才能完成我的想法,请大家帮助。