如何在chainer v2.0中修复学习网络中的图层参数?

时间:2017-10-23 12:54:09

标签: chainer

假设我在其他数据库上预先训练了网络。由于过度拟合问题,我的数据库不是很多样化,因此过度拟合问题非常严重。我想在chainer v2.0中加载预训练的网络参数但要修复前几个层,在这种情况下,我应该在chainer1.0中使用什么,我知道在chainer1.0中有volatile关键字但在v2中已弃用0.0。

在前几个图层中进行处理时,我应该在chainer.no_backprop_mode():内使用def __call__吗?

1 个答案:

答案 0 :(得分:0)

是的,您可以在前向计算代码中使用chainer.no_backprop_mode()上下文管理器来修复特定图层的参数。这是一个例子:

def __call__(self, x):
    with chainer.no_backprop_mode():
        h1 = F.relu(self.l1(x))
    h2 = F.relu(self.l2(h1))
    return self.l3(h2)