我想通过克隆它并将其放置在几个不同的大型网络中来训练一个小型网络模块。但是当我把它放在一个这样的上下文时,它的参数在训练期间不会更新。下面是一个显示问题的小例子。
我像这样制作小型原始网络:
function mkPrimitive()
local inp = nn.Linear(2, 2)()
local outp = nn.Tanh()(nn.Linear(2, 2)(nn.Tanh()(inp)))
return nn.gModule({inp}, {outp})
end
prim = mkPrimitive()
然后我将它放入一个更大的网络,名为toTrain
,如下所示:
function mkNet()
local fst = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')()
local snd = prim:clone('weight', 'gradWeight', 'bias', 'gradBias')(fst)
return nn.gModule({fst}, {snd})
end
toTrain = mkNet()
然后我训练更大的网络,并打印出它的参数,并在迭代时打印prim
的参数。我看到的是,较大的toTrain
网络参数在训练期间发生了变化,而prim
则没有。以下是培训代码。有办法解决这个问题吗?
numRuns = 10
function train()
local crit = nn.MSECriterion()
for i = 1, numRuns do
toTrain:zeroGradParameters()
local inData = torch.rand(1, 2) --make some input/output data
local outData = torch.rand(1, 2)
local pred = toTrain:forward(inData)
local err = crit:forward(pred, outData)
local grad = crit:backward(pred, outData)
toTrain:backward(inData, grad)
toTrain:updateParameters(0.01)
local bigWs = toTrain:getParameters()
local primWs = prim:getParameters()
print(bigWs) --the params for the big network change during learning,
print(primWs) --but the ones for the primitive don't.
print("------------------------------")
end
end
train()