Torch: back-propagating a loss computed on a subset of the outputs

Time: 2017-02-03 04:02:29

Tags: neural-network, torch

I have a simple convolutional neural network whose output is a single-channel 4x4 feature map. During training, the (regression) loss should be computed on only a single one of the 16 output values, and the location of that value is only decided after the forward pass. How can I compute the loss from this one output while making sure that all irrelevant gradients are zero during the backward pass?

Suppose I have the following simple model in Torch:

require 'nn'

-- the input
local batch_sz = 2
local x = torch.Tensor(batch_sz, 3, 100, 100):uniform(-1,1)

-- the model
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 128, 9, 9, 9, 9, 1, 1))
net:add(nn.SpatialConvolution(128, 1, 3, 3, 3, 3, 1, 1))
net:add(nn.Squeeze(1, 3))

print(net)

-- the loss (don't know how to employ it yet)
local loss = nn.SmoothL1Criterion()

-- forwarding x through the network results in a 2x4x4 output
y = net:forward(x)

print(y)

I have looked at nn.SelectTable, and it seems that if I convert the output into a table I could achieve what I want. Is that the right approach?
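For reference, here is a minimal sketch (with made-up values) of how nn.SelectTable propagates gradients: it routes the incoming gradient to the selected table element and zero-fills all the others, which is the behaviour I am hoping to exploit:

require 'nn'

-- nn.SelectTable(i) forwards only the i-th table element; on backward it
-- passes the incoming gradient to element i and zero-fills all the others
local sel = nn.SelectTable(2)
local input = {torch.ones(2), torch.ones(2):mul(3)}

local out = sel:forward(input)                          -- the tensor {3, 3}
local gradIn = sel:backward(input, torch.Tensor{0.5, 0.5})
print(gradIn[1])   -- zeros
print(gradIn[2])   -- the gradient {0.5, 0.5}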

1 answer:

Answer 0 (score: 0)

This is my current solution. It works by splitting the output into a table and then using nn.SelectTable():backward() to obtain the full gradient:

require 'nn'

-- the input
local batch_sz = 2
local x = torch.Tensor(batch_sz, 3, 100, 100):uniform(-1,1)

-- the model
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 128, 9, 9, 9, 9, 1, 1))
net:add(nn.SpatialConvolution(128, 1, 3, 3, 3, 3, 1, 1))
net:add(nn.Squeeze(1, 3))

-- convert output into a table format
net:add(nn.View(1, -1))         -- vectorize
net:add(nn.SplitTable(1, 1))    -- split all outputs into table elements

print(net)

-- the loss
local loss = nn.SmoothL1Criterion()

-- forwarding x through the network now yields a table of batch_sz*4*4 = 32 one-element tensors
y = net:forward(x)

print(y)

-- returns the index in the output table for a given sample and feature-map location
function get_sample_idx(feat_h, feat_w, smpl_idx, feat_r, feat_c)
    local idx = (smpl_idx - 1) * feat_h * feat_w
    return idx + feat_c + ((feat_r - 1) * feat_w)
end
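-- e.g. for a 4x4 map, get_sample_idx(4, 4, 2, 3, 4) = (2-1)*16 + 4 + (3-1)*4 = 28,
-- i.e. sample 2, row 3, column 4 maps to table element 28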

-- I want to back-propagate the loss of this sample at this feature location
local smpl_idx = 2
local feat_r = 3
local feat_c = 4
-- get the actual index location in the output table (for a 4x4 output feature map)
local out_idx = get_sample_idx(4, 4, smpl_idx, feat_r, feat_c)

-- the (fake) ground-truth
local gt = torch.rand(1)

-- compute loss on the selected feature map location for the selected sample
local err = loss:forward(y[out_idx], gt)
-- compute the loss gradient, as if there were only this one location
local dE_dy = loss:backward(y[out_idx], gt)
-- now convert into the full loss gradient (zeroing out the irrelevant gradients)
local full_dE_dy = nn.SelectTable(out_idx):backward(y, dE_dy)
-- do back-propagation through the whole network
net:backward(x, full_dE_dy)

print("The full dE/dy")
print(table.unpack(full_dE_dy))

I would really appreciate it if someone could point out a simpler or more efficient way.
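One candidate for such a simplification (a sketch only, not part of the answer above, and assuming net, x and loss are the plain tensor-output model from the question, i.e. without the View/SplitTable layers) is to skip the table conversion entirely and hand net:backward() a gradient tensor that is zero everywhere except at the chosen location:

-- sketch: loss on a single location without converting the output to a table
-- (y is then a batch_sz x 4 x 4 tensor)
local y = net:forward(x)

local smpl_idx, feat_r, feat_c = 2, 3, 4
local gt = torch.rand(1)

-- wrap the single selected output value in a 1-element tensor for the criterion
local y_sel = torch.Tensor(1):fill(y[smpl_idx][feat_r][feat_c])
local err = loss:forward(y_sel, gt)
local dE_dsel = loss:backward(y_sel, gt)

-- full gradient: zeros everywhere, the criterion's gradient at the one location
local full_dE_dy = y:clone():zero()
full_dE_dy[smpl_idx][feat_r][feat_c] = dE_dsel[1]

net:backward(x, full_dE_dy)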