为什么在这里用mc dropout更新贝叶斯cnn的实现中的目标?
def update_target(target, original, update_rate):
for target_param, param in zip(target.parameters(), original.parameters()):
target_param.data.copy_((1.0 - update_rate) * target_param.data + update_rate*param.data)
答案 0 :(得分:1)
您提到的实现是并行的数据。
这意味着,作者打算训练具有相同体系结构但具有不同超参数的多个网络。
尽管以一种非常规的方式,update_target
的作用是:
update_target(net_test, net, 0.001)
与net相比,它以较低的学习速率更新net_test,但对实际正在训练的原始net进行完全相同的参数更改。只有变化尺度不同。
我认为这在计算效率方面是有用的,因为在主要训练阶段实际上只对其中一个网络进行了“训练”:
outputs = net(inputs)
loss = CE(outputs, labels)
loss.backward()
optimizer.step()
每步前进次数减少一次,后退次数减少一次。