Torch out of memory in thread when using torch.serialize twice

Date: 2016-07-17 12:46:54

Tags: multithreading lua torch luajit

I'm trying to add a parallel data loader to torch-dataframe for torchnet compatibility. I've adapted tnt.ParallelDatasetIterator so that:

  1. The basic batch is loaded outside the threads
  2. The batch is serialized and sent to a thread
  3. In the thread, the batch is deserialized and the batch data is converted to tensors
  4. The tensors are returned in a table with input and target keys to match the tnt.Engine setup

The problem occurs the second time enqueue is called, with the error .../torch_distro/install/bin/luajit: not enough memory. I'm currently only working with mnist via an adapted mnist-example. The enqueue loop now looks like this (with debugging memory output):

    -- `samplePlaceholder` stands in for samples which have been
    -- filtered out by the `filter` function
    local samplePlaceholder = {}
    
    -- The enque does the main loop
    local idx = 1
    local function enqueue()
      while idx <= size and threads:acceptsjob() do
        local batch, reset = self.dataset:get_batch(batch_size)
    
        if (reset) then
          idx = size + 1
        else
          idx = idx + 1
        end
    
        if (batch) then
          local serialized_batch = torch.serialize(batch)
    
          -- In the parallel section only the to_tensor is run in parallel
          --  this should though be the computationally expensive operation
          threads:addjob(
            function(argList)
              io.stderr:write("\n Start");
              io.stderr:write("\n 1: " ..tostring(collectgarbage("count")))
              local origIdx, serialized_batch, samplePlaceholder = unpack(argList)
    
              io.stderr:write("\n 2: " ..tostring(collectgarbage("count")))
              local batch = torch.deserialize(serialized_batch)
              serialized_batch = nil
    
              collectgarbage()
              collectgarbage()
    
              io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))
              batch = transform(batch)
    
              io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))
              local sample = samplePlaceholder
              if (filter(batch)) then
                sample = {}
                sample.input, sample.target = batch:to_tensor()
              end
              io.stderr:write("\n 5: " ..tostring(collectgarbage("count")))
    
              collectgarbage()
              collectgarbage()
              io.stderr:write("\n 6: " ..tostring(collectgarbage("count")))
    
              io.stderr:write("\n End \n");
              return {
                sample,
                origIdx
              }
            end,
            function(argList)
              sample, sampleOrigIdx = unpack(argList)
            end,
            {idx, serialized_batch, samplePlaceholder}
          )
        end
      end
    end
    

I've sprinkled collectgarbage calls throughout and tried removing everything that isn't needed. The memory output is rather straightforward:

     Start
     1: 374840.87695312
     2: 374840.94433594
     3: 372023.79101562
     4: 372023.85839844
     5: 372075.41308594
     6: 372023.73632812
     End 
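
One quick way to narrow down where the memory goes is to measure the serialized payload itself: torch.serialize returns an ordinary Lua string, so its length is the size in bytes. A small diagnostic sketch (not part of the original loop):

    -- a batch that is "small" but serializes to hundreds of MB means
    -- something extra is being captured into the serialized object
    local s = torch.serialize(batch)
    print(("serialized batch: %.2f MB"):format(#s / 2^20))
    s = nil
    collectgarbage()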
    

The function that loops the enqueue is the non-ordered iterator function, and it is trivial (the memory error is thrown at the second enqueue):

    iterFunction = function()
      while threads:hasjob() do
        enqueue()
        threads:dojob()
        if threads:haserror() then
          threads:synchronize()
        end
        enqueue()
    
        if table.exact_length(sample) > 0 then
          return sample
        end
      end
    end
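
For context, the snippets assume a thread pool from the torch threads package; its construction isn't shown above, but it would look roughly like this (the worker count is a placeholder):

    local threadsLib = require 'threads'

    -- each worker runs in its own Lua state, so torch must be
    -- required again inside the init closure; keep the closure free
    -- of large upvalues, since they get serialized into every worker
    local threads = threadsLib.Threads(
      4,                 -- number of worker threads (hypothetical)
      function(threadid)
        require 'torch'
      end
    )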
    

1 Answer:

Answer 0 (score: 1)

So the problem was torch.serialize: the function in the setup had the entire dataset coupled to it. Adding the following resolved the problem:

serialized_batch = nil
collectgarbage()
collectgarbage()
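
(The double collectgarbage() is a common Torch idiom: userdata with finalizers, such as tensor storages, are finalized on the first pass and only actually reclaimed on the second.)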

I further wanted to know what was taking up so much space, and the culprit turned out to be that I had defined the function in an environment containing a large dataset that got intertwined with the function, massively increasing its size. Here is the original definition of the local data:

mnist = require 'mnist'
local dataset = mnist[mode .. 'dataset']()

-- PROBLEMATIC LINE BELOW --
local ext_resource = dataset.data:reshape(dataset.data:size(1),
  dataset.data:size(2) * dataset.data:size(3)):double()

-- Create a Dataframe with the label. The actual images will be loaded
--  as an external resource
local df = Dataframe(
  Df_Dict{
    label = dataset.label:totable(),
    row_id = torch.range(1, dataset.data:size(1)):totable()
  })

-- Since the mnist package already has taken care of the data
--  splitting we create a single subsetter
df:create_subsets{
  subsets = Df_Dict{core = 1},
  class_args = Df_Tbl({
    batch_args = Df_Tbl({
      label = Df_Array("label"),
      data = function(row)
        return ext_resource[row.row_id]
      end
    })
  })
}

It turns out that removing the line I highlighted reduced the memory usage from 358 MB to 0.0008 MB! The code I used to test this was:

local mem = {}
table.insert(mem, collectgarbage("count"))

local ser_data = torch.serialize(batch.dataset)
table.insert(mem, collectgarbage("count"))

local ser_retriever = torch.serialize(batch.batchframe_defaults.data)
table.insert(mem, collectgarbage("count"))

local ser_raw_retriever = torch.serialize(function(row)
  return ext_resource[row.row_id]
end)
table.insert(mem, collectgarbage("count"))

local serialized_batch = torch.serialize(batch)
table.insert(mem, collectgarbage("count"))

for i=2,#mem do
  print(i-1, (mem[i] - mem[i-1])/1024)
end

which originally produced the output:

1   0.0082607269287109  
2   358.23344707489 
3   0.0017471313476562  
4   358.90182781219 

and after the fix:

1   0.0094480514526367  
2   0.00080204010009766 
3   0.00090408325195312 
4   0.010146141052246
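
The underlying mechanism is plain Lua closure semantics: a function keeps references to its upvalues, and torch.serialize writes those upvalues out together with the function's bytecode. A minimal, self-contained sketch of the effect (sizes approximate):

require 'torch'

local huge = torch.rand(10000, 784)          -- roughly 60 MB of doubles

-- coupled captures huge as an upvalue, so serializing the function
-- drags the entire tensor along with the bytecode
local coupled = function(i) return huge[i] end

-- detached has no large upvalues, so only the bytecode is written
local detached = function(i) return torch.zeros(784) end

print(#torch.serialize(coupled)  / 2^20)     -- tens of MB
print(#torch.serialize(detached) / 2^20)     -- close to 0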

I tried using setfenv on the function, but it didn't solve the problem. There is still a performance penalty from sending serialized data to the threads, but the main problem is resolved, and without the expensive data retriever the serialized functions are considerably smaller.
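
That setfenv didn't help is expected: setfenv swaps the table a function uses for global lookups, while upvalues are captured lexically and travel with the function regardless of its environment. A hypothetical reconstruction of why the attempt cannot work:

require 'torch'

local big = torch.rand(1000, 1000)           -- ~8 MB of doubles
local f = function(i) return big[i] end

-- setfenv replaces the table used for *global* name lookups inside f;
-- big is a lexical upvalue, not a global, so it stays attached
setfenv(f, {})

print(#torch.serialize(f) / 2^20)            -- still ~8 MB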