放置在较大的模块中时,克隆网络的副本不会进行训练

时间:2015-10-10 14:25:29

标签: lua neural-network torch

我想通过克隆它并将其放置在几个不同的大型网络中来训练一个小型网络模块。但是当我把它放在一个这样的上下文时,它的参数在训练期间不会更新。下面是一个显示问题的小例子。

我像这样制作小型原始网络:

function mkPrimitive()
    local inp  = nn.Linear(2, 2)()
    local outp = nn.Tanh()(nn.Linear(2, 2)(nn.Tanh()(inp)))
    return nn.gModule({inp}, {outp})
end
prim = mkPrimitive()

然后我将它放入一个更大的网络,名为toTrain,如下所示:

function mkNet()
    local fst = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')()
    local snd = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')(fst)
    return nn.gModule({fst}, {snd})
end

toTrain = mkNet()

然后我训练更大的网络,并打印出它的参数,并在迭代时打印prim的参数。我看到的是,较大的toTrain网络参数在训练期间发生了变化,而prim则没有。以下是培训代码。有办法解决这个问题吗?

numRuns = 10
function train()
    local crit = nn.MSECriterion()
    for i = 1, numRuns do
        toTrain:zeroGradParameters()

        local inData = torch.rand(1, 2)  --make some input/output data
        local outData = torch.rand(1, 2)

        local pred = toTrain:forward(inData)
        local err = crit:forward(pred, outData)
        local grad = crit:backward(pred, outData)
        toTrain:backward(inData, grad)
        toTrain:updateParameters(0.01)

        local bigWs = toTrain:getParameters()
        local primWs = prim:getParameters()

        print(bigWs) --the params for the big network change during learning,
        print(primWs) --but the ones for the primitive don't.
        print("------------------------------")
    end
end
train()

1 个答案:

答案 0 :(得分:0)

getParameters()会更改网络参数的内存位置,因此任何共享都会丢失。有关详细信息,请参阅this