I am trying to implement the paper https://arxiv.org/pdf/1804.06962.pdf in Lua/Torch7.
The forward pass runs without problems, but during the backward pass, modele.gapbranch:backward(n, loss_grad), I get this error:
/home/narimene/distro/install/bin/luajit: ...e/narimene/distro/install/share/lua/5.1/nn/Container.lua:67:
In 2 module of nn.Sequential:
/home/narimene/distro/install/share/lua/5.1/nn/Concat.lua:92: bad argument #1 to 'narrow' (number expected, got nil)
stack traceback:
    [C]: in function 'narrow'
    /home/narimene/distro/install/share/lua/5.1/nn/Concat.lua:92: in function </home/narimene/distro/install/share/lua/5.1/nn/Concat.lua:47>
    [C]: in function 'xpcall'
    ...e/narimene/distro/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    .../narimene/distro/install/share/lua/5.1/nn/Sequential.lua:84: in function 'backward'
    gap2.lua:240: in function 'opfunc'
    /home/narimene/distro/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
    gap2.lua:247: in main chunk
    [C]: in function 'dofile'
    ...ene/distro/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
    [C]: at 0x563fabe66570

WARNING: If you see a stack trace below, it doesn't point to the place where this error occurred. Please use only the one above.

stack traceback:
    [C]: in function 'error'
    ...e/narimene/distro/install/share/lua/5.1/nn/Container.lua:67: in function 'rethrowErrors'
    .../narimene/distro/install/share/lua/5.1/nn/Sequential.lua:84: in function 'backward'
    gap2.lua:240: in function 'opfunc'
    /home/narimene/distro/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
    gap2.lua:247: in main chunk
    [C]: in function 'dofile'
    ...ene/distro/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
    [C]: at 0x563fabe66570
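Looking at the trace, the failure is inside nn.Concat, and I notice that in my code I construct it with no arguments: nn.Concat():add(self.cls):add(self.cls_erase). As far as I understand from the torch/nn documentation, the constructor takes the dimension along which the branch outputs are concatenated, for example:

m = nn.Concat(2)                    -- concatenate branch outputs along dimension 2
m:add(nn.Linear(5, 3))
m:add(nn.Linear(5, 7))
out = m:forward(torch.randn(4, 5))  -- out is 4x10

So the missing dimension may be why 'narrow' receives nil, but I am not sure how to fix it in my setup.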
Here is the code (gap2.lua):
require 'nn'
require 'cunn'
require 'cutorch'
local GapBranch, Parent = torch.class('nn.GapBranch', 'nn.Module')
function GapBranch:__init(label, num_classes, args, threshold)
    Parent.__init(self)
    self.gt_labels = label
    num_classes = num_classes ~= nil and num_classes or 10
    self.threshold = threshold or 0.6
    self.gapbranch = nn.Sequential()
    self.gapbranch:add(nn.SpatialConvolution(3, 512, 3, 3, 1, 1, 1, 1)) -- this line is to be removed
    self.cls = self:classifier(512, num_classes)
    self.cls_erase = self:classifier(512, num_classes)
    self.gapbranch:add(nn.Concat():add(self.cls):add(self.cls_erase))
    --self.gapbranch:add(self.cls_erase)

    -- Optimizer
    self.loss_cross_entropy = nn.CrossEntropyCriterion():cuda()
end
function GapBranch:classifier(in_planes, out_planes)
    gapcnn = nn.Sequential()
    gapcnn:add(nn.SpatialConvolution(in_planes, 1024, 3, 3, 1, 1, 1, 1))
    gapcnn:add(nn.ReLU())
    gapcnn:add(nn.SpatialConvolution(1024, 1024, 3, 3, 1, 1, 1, 1))
    gapcnn:add(nn.ReLU())
    gapcnn:add(nn.SpatialConvolution(1024, out_planes, 1, 1, 1, 1))
    return gapcnn
end
function mulTensor(tensor1, tensor2)
    newTensor = torch.Tensor(tensor1:size()):cuda()
    for i = 1, tensor1:size()[1] do
        for j = 1, tensor1:size()[2] do
            newTensor[{i, j}] = torch.cmul(tensor1[{i, j}], tensor2[{i, 1}])
        end
    end
    return newTensor
end
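-- Aside: the loops in mulTensor could be replaced by a broadcasted cmul
-- (a sketch, assuming tensor2 is batch x 1 x H x W so it can be expanded
-- over the channel dimension of tensor1):
--   function mulTensor(tensor1, tensor2)
--       return torch.cmul(tensor1, tensor2:expandAs(tensor1))
--   end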
function GapBranch:erase_feature_maps(atten_map_normed, feature_maps, threshold)
    if #atten_map_normed:size() > 3 then
        atten_map_normed = torch.squeeze(atten_map_normed)
    end
    atten_shape = atten_map_normed:size()

    pos = torch.ge(atten_map_normed, threshold)
    mask = torch.ones(atten_shape):cuda() -- cuda
    mask[pos] = 0.0
    m = nn.Unsqueeze(2)
    m = m:cuda()
    mask = m:forward(mask)

    erased_feature_maps = mulTensor(feature_maps, mask) -- Variable
    return erased_feature_maps
end
function GapBranch:normalize_atten_maps(atten_map)
    atten_shape = atten_map:size()

    batch_mins, _ = torch.min(atten_map:view(atten_shape[1], -1), 2)
    batch_maxs, _ = torch.max(atten_map:view(atten_shape[1], -1), 2)
    atten_normed = torch.cdiv(
        atten_map:view(atten_shape[1], -1) - batch_mins:expandAs(atten_map:view(atten_shape[1], -1)),
        (batch_maxs - batch_mins):expandAs(atten_map:view(atten_shape[1], -1)))
    atten_normed = atten_normed:view(atten_shape)
    return atten_normed
end
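-- Aside: if an attention map is constant, batch_maxs - batch_mins is zero and
-- the cdiv above produces nan/inf; a small epsilon in the denominator would
-- guard against that, e.g. (batch_maxs - batch_mins):add(1e-8).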
function GapBranch:get_atten_map(feature_maps, gt_labels, normalize)
    normalize = normalize or true
    label = gt_labels:long()

    feature_map_size = feature_maps:size()
    batch_size = feature_map_size[1]
    atten_map = torch.zeros(feature_map_size[1], feature_map_size[3], feature_map_size[4])
    atten_map = atten_map:cuda()
    for batch_idx = 1, batch_size do
        -- label.data[batch_idx]
        -- label[batch_idx]
        print('label ', label:size())
        print('feature_maps ', feature_maps:size())
        atten_map[{batch_idx}] = torch.squeeze(feature_maps[{batch_idx, label[batch_idx]}])
    end

    if normalize then
        atten_map = self:normalize_atten_maps(atten_map)
    end
    return atten_map
end
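-- Aside: "normalize = normalize or true" always evaluates to true, even when
-- false is passed in (false or true == true); the usual Lua idiom for a
-- default value of true is:
--   if normalize == nil then normalize = true end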
function GapBranch:gaplayer()
    gaplayer = nn.Sequential()
    gaplayer:add(nn.SpatialZeroPadding(1, 1, 1, 1))
    gaplayer:add(nn.SpatialAveragePooling(3, 3, 1, 1))
    return gaplayer
end
function GapBranch:updateOutput(input) -- need label
    -- Backbone
    feat = self.gapbranch:get(1):forward(input)

    self.gap = self:gaplayer()
    self.gap:cuda()
    feat3 = self.gap:forward(feat)

    m = nn.Unsqueeze(2)
    m = m:cuda()

    -- Branch A
    out = self.gapbranch:get(2):get(1):forward(feat3)
    self.map1 = out
    logits_1 = torch.squeeze(torch.mean(torch.mean(out, 3), 4))
    logits_1 = m:forward(logits_1)
    print('logits_1 ', logits_1:size())

    --feat5 = self.gapbranch:get(2):get(2):forward(feat3)

    localization_map_normed = self:get_atten_map(out, self.gt_labels, true)
    self.attention = localization_map_normed
    feat_erase = self:erase_feature_maps(localization_map_normed, feat3, self.threshold)

    -- Branch B
    out_erase = self.gapbranch:get(2):get(2):forward(feat_erase)
    self.map_erase = out_erase
    logits_ers = torch.squeeze(torch.mean(torch.mean(out_erase, 3), 4))
    m = nn.Unsqueeze(2)
    m = m:cuda()
    logits_ers = m:forward(logits_ers)
    print('logits_ers ', logits_ers:size())

    return {logits_1, logits_ers}
end
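-- Aside: updateOutput never runs the nn.Concat container itself; the two
-- branches are forwarded individually, and on different inputs (feat3 vs
-- feat_erase). nn.Concat applies every branch to the SAME input, so a
-- ParallelTable may be closer to this usage (a sketch, not tested):
--   branches = nn.ParallelTable()
--   branches:add(self.cls)        -- would receive feat3
--   branches:add(self.cls_erase)  -- would receive feat_erase
--   outs = branches:forward({feat3, feat_erase})  -- returns {out, out_erase}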
function GapBranch:get_loss(resModele, gt_labels)
    --[[
    if self.onehot == 'True' then
        gt = gt_labels:float()
    else
        gt = gt_labels:long()
    end
    --]]
    print('resModele ', resModele[1])
    loss_cls = self.loss_cross_entropy:forward(resModele[1], gt_labels)
    loss_cls_ers = self.loss_cross_entropy:forward(resModele[2], gt_labels)

    loss_val = loss_cls + loss_cls_ers
    return {loss_val, }
end
require 'paths'
if (not paths.filep("cifar10torchsmall.zip")) then
    os.execute('wget -c https://s3.amazonaws.com/torch7/data/cifar10torchsmall.zip')
    os.execute('unzip cifar10torchsmall.zip')
end
trainset = torch.load('cifar10-train.t7')
testset = torch.load('cifar10-test.t7')
classes = {'airplane', 'automobile', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
-- ignore setmetatable for now, it is a feature beyond the scope of this tutorial. It sets the index operator.
setmetatable(trainset,
    {__index = function(t, i)
        return {t.data[i], t.label[i]}
    end}
);
trainset.data = trainset.data:double() -- convert the data from a ByteTensor to a DoubleTensor.

function trainset:size()
    return self.data:size(1)
end
mean = {} -- store the mean, to normalize the test set in the future
stdv = {} -- store the standard-deviation for the future
for i = 1, 3 do -- over each image channel
    mean[i] = trainset.data[{ {}, {i}, {}, {} }]:mean() -- mean estimation
    print('Channel ' .. i .. ', Mean: ' .. mean[i])
    trainset.data[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction

    stdv[i] = trainset.data[{ {}, {i}, {}, {} }]:std() -- std estimation
    print('Channel ' .. i .. ', Standard Deviation: ' .. stdv[i])
    trainset.data[{ {}, {i}, {}, {} }]:div(stdv[i]) -- std scaling
end
trainset.data = trainset.data:cuda()
trainset.label = trainset.label:cuda()
modele = nn.GapBranch(trainset.label):cuda()
modele.gapbranch = modele.gapbranch:cuda()
print(modele.gapbranch)
theta, gradTheta = modele.gapbranch:getParameters()
optimState = {learningRate = 0.15}
require 'optim'
for epoch = 1, 1 do
    function feval(theta)
        for i = 1, 1 do
            modele.gapbranch:zeroGradParameters()
            m = nn.Unsqueeze(1)
            m = m:cuda()
            n = m:forward(trainset.data[i])

            h = modele:forward(n)
            j = modele:get_loss(h, trainset.label[i])
            loss_cls_grad = modele.loss_cross_entropy:backward(h[1], trainset.label[i])
            loss_cls_ers_grad = modele.loss_cross_entropy:backward(h[2], trainset.label[i])
            loss_grad = loss_cls_grad + loss_cls_ers_grad
            loss_grad = torch.randn(1, 10, 32, 32):cuda()
            modele.gapbranch:backward(n, loss_grad)
        end
        return j, gradTheta
    end
    print('***************************')
    optim.sgd(feval, theta, optimState)
end
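For completeness: in feval I overwrite loss_grad with torch.randn(1, 10, 32, 32), which was only a debugging attempt to get past a shape mismatch. My understanding of the usual nn training pattern (a generic sketch; net, input and target are placeholders, not my model) is that the gradOutput passed to :backward must be the gradient of the loss with respect to the exact output that the same module returned from :forward:

local crit = nn.CrossEntropyCriterion():cuda()
local out = net:forward(input)              -- out: batchSize x numClasses
local loss = crit:forward(out, target)
local gradOut = crit:backward(out, target)  -- same shape and type as out
net:backward(input, gradOut)                -- gradOut must match net's own output

In my case the shapes do not line up because the Concat container is never forwarded as a whole, which is probably related to the error.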
I would be very grateful if anyone could help me.