更新分类图像中的目标

时间:2020-07-20 15:29:41

标签: pytorch

为什么在这里用mc dropout更新贝叶斯cnn的实现中的目标?

https://github.com/sungyubkim/MCDO/blob/master/Bayesian_CNN_with_MCDO.ipynb?fbclid=IwAR18IMLcdUUp90TRoYodsJS7GW1smk-KGYovNpojn8LtRhDQckFI_gnpOYc

def update_target(target, original, update_rate):
        for target_param, param in zip(target.parameters(), original.parameters()):
            target_param.data.copy_((1.0 - update_rate) * target_param.data + update_rate*param.data)

1 个答案:

答案 0 :(得分:1)

您提到的实现是并行的数据。

这意味着,作者打算训练具有相同体系结构但具有不同超参数的多个网络。

尽管以一种非常规的方式,update_target的作用是:

update_target(net_test, net, 0.001)

与net相比,它以较低的学习速率更新net_test,但对实际正在训练的原始net进行完全相同的参数更改。只有变化尺度不同。

我认为这在计算效率方面是有用的,因为在主要训练阶段实际上只对其中一个网络进行了“训练”:

outputs = net(inputs)
loss = CE(outputs, labels)
loss.backward()
optimizer.step()

每步前进次数减少一次,后退次数减少一次。