如何在Torch nn包中禁用omp?

时间:2015-05-20 22:56:28

标签: lua openmp torch

具体来说,当输入张量的大小很小时,我希望nn.LogSoftMax不使用omp。我有一个小脚本来测试运行时间。

require 'nn'
my_lsm = function(t)
    o = torch.zeros((#t)[1])
    sum = 0.0
    for i = 1,(#t)[1] do
        o[i] = torch.exp(t[i])
        sum = sum + o[i]
    end
    o = o / sum
    return torch.log(o)
end

ii=torch.randn(arg[1])
m=nn.LogSoftMax()

timer = torch.Timer()
timer:stop()
timer:reset()
timer:resume()
my_lsm(ii)
print(timer:time().real)

timer:stop()
timer:reset()
timer:resume()
m:forward(ii)
print(timer:time().real)

如果arg[1]为10,那么我的基本日志softmax函数运行得更快:

0.00021696090698242
0.033425092697144

但是一旦arg[1]达到10,000,000,omp确实有很多帮助:

29.561321973801 
0.11547803878784

所以我怀疑omp开销非常高。如果我的代码必须使用小输入多次调用log softmax(称张量大小仅为3),则会花费太多时间。有没有办法在某些情况下手动禁用omp使用(但不总是)?

1 个答案:

答案 0 :(得分:4)

  

在某些情况下(但并非总是),有没有办法手动禁用omp使用?

如果您真的想这样做,可以使用torch.setnumthreadstorch.getnumthreads这样的话:

local nth = torch.getnumthreads()
torch.setnumthreads(1)
-- do something
torch.setnumthreads(nth)

所以你可以按照以下方式修补nn.LogSoftMax

nn.LogSoftMax.updateOutput = function(self, input)
  local nth = torch.getnumthreads()
  torch.setnumthreads(1)
  local out = input.nn.LogSoftMax_updateOutput(self, input)
  torch.setnumthreads(nth)
  return out
end