我正在运行具有以下结构的代码:
network = createNetwork() -- loading a pre-trained network.
function train()
for i=1,#trainingsamples do
local ip = loadInput()
local ip_1 = someImageProcessing(ip)
local ip_2 = someImageProcessing(ip)
network:forward( ...some manipulation on ip_1,ip_2...)
network:backward()
collectgarbage('collect')
print debug.getlocal -- all local variables.
end
end
我期待collectgarbage()将释放ip_1,ip_2和ip所拥有的所有内存。但我可以看到内存没有被释放。这会导致内存泄漏。我想知道发生了什么。有人可以帮我理解collectgarbage()的奇怪行为并修复内存泄漏。
我真的很抱歉我无法添加完整的代码。希望我添加的代码片段足以理解我的代码流程,我的网络培训代码与标准的CNN培训代码非常相似。
编辑:
很抱歉没有提及变量被声明为本地变量,并在示例代码段中使用变量的关键字。我现在已经编辑过了。唯一的全局变量是在列车功能之外声明的网络,我将ip_1,ip_2作为网络的输入。另外,我在下面添加了我的实际代码的修剪版本。
network = createNetwork()
function trainNetwork()
local parameters,gradParameters = network:getParameters()
network:training() -- set flag for dropout
local bs = 1
local lR = params.learning_rate / torch.sqrt(bs)
local optimConfig = {learningRate = params.learning_rate,
momentum = params.momentum,
learningRateDecay = params.lr_decay,
beta1 = params.optim_beta1,
beta2 = params.optim_beta2,
epsilon = params.optim_epsilon}
local nfiles = getNoofFiles('train')
local weights = torch.Tensor(params.num_classes):fill(1)
criterion = nn.ClassNLLCriterion(weights)
for ep=1,params.epochs do
IMAGE_SEQ = 1
while (IMAGE_SEQ <= nfiles) do
xlua.progress(IMAGE_SEQ, nfiles)
local input, inputd2
local color_image, depth_image2, target_image
local nextInput = loadNext('train')
color_image = nextInput.data.rgb
depth_image2 = nextInput.data.depth
target_image = nextInput.data.labels
input = network0:forward(color_image) -- process RGB
inputd2 = networkd:forward(depth_image2):squeeze() -- HHA
local input_concat = torch.cat(input,inputd2,1):squeeze() -- concat RGB, HHA
collectgarbage('collect')
target = target_image:reshape(params.imWidth*params.imHeight) -- reshape target as vector
-- create closure to evaluate f(X) and df/dX
local loss = 0
local feval = function(x)
-- get new parameters
if x ~= parameters then parameters:copy(x) end
collectgarbage()
-- reset gradients
gradParameters:zero()
-- f is the average of all criterions
-- evaluate function for complete mini batch
local output = network:forward(input_concat) -- run forward pass
local err = criterion:forward(output, target) -- compute loss
loss = loss + err
-- estimate df/dW
local df_do = criterion:backward(output, target)
network:backward(input_concat, df_do) -- update parameters
local _,predicted_labels = torch.max(output,2)
predicted_labels = torch.reshape(predicted_labels:squeeze():float(),params.imHeight,params.imWidth)
return err,gradParameters
end -- feval
pm('Training loss: '.. loss, 3)
_,current_loss = optim.adam(feval, parameters, optimConfig)
print ('epoch / current_loss ',ep,current_loss[1])
os.execute('cat /proc/$PPID/status | grep RSS')
collectgarbage('collect')
-- for memory leakage debugging
print ('locals')
for x, v in pairs(locals()) do
if type(v) == 'userdata' then
print(x, v:size())
end
end
print ('upvalues')
for x,v in pairs(upvalues()) do
if type(v) == 'userdata' then
print(x, v:size())
end
end
end -- ii
print(string.format('Loss: %.4f Epoch: %d grad-norm: %.4f',
current_loss[1], ep, torch.norm(parameters)/torch.norm(gradParameters)))
if (current_loss[1] ~= current_loss[1] or gradParameters ~= gradParameters) then
print ('nan loss or gradParams. quiting...')
abort()
end
-- some validation code here
end --epochs
print('Training completed')
end
答案 0 :(得分:4)
正如@Adam在评论中所说,in_1
和in_2
变量继续被引用,其值不能被垃圾收集。即使您将它们更改为局部变量,它们也不会在那时被垃圾收集,因为它们所定义的块尚未关闭。
您可以执行的操作是在调用in_1
之前将in_2
和nil
值设置为collectgarbage
,这样可以使先前分配的值无法访问且符合垃圾条件采集。只有在没有可能存储相同值的其他变量时,这才有效。
答案 1 :(得分:0)
+1保罗在上面的回答;但请注意&#34;应该&#34;。几乎所有的时间你都会好起来的。但是,如果是你的代码变得更加复杂(并且你开始传递内存对象并对它们进行处理),你可能会发现Lua gc偶尔可能决定保留一个内存对象只是比预期的更长一点。但是不要担心(或浪费时间试图找出原因),最终所有未使用的内存objs将由Lua gc收集。垃圾收集器是一种复杂的算法,有时可能看起来有点不确定。
答案 2 :(得分:0)
您可以创建全局变量来存储值。所以这些变量一直都是可用的。因此,在重写值之前,这样的变量gc无法收集它们。 只需使vars本地化并从范围调用gc。 GC的第一个循环也可以只调用终结器和第二个空闲内存。 但不确定。所以你可以尝试两次调用gc。
function train()
do
local in = loadInput()
local in_1 = someImageProcessing(in)
local in_2 = someImageProcessing(in)
network:forward( ...some manipulation on in_1,in_2...)
network:backward()
end
collectgarbage('collect')
collectgarbage('collect')
print debug.getlocal -- all local variables.
PS。 in
在Lua