SpatialDepthWiseConvolution Xception反向传播错误

时间:2018-02-23 12:57:33

标签: lua neural-network deep-learning torch

我正在尝试实现受Xception创意启发的NN。 无法理解我的模型有什么问题...

local torch = require 'torch'
local nn = require 'nn'

dofile('GlobalAveragePooling.lua')

local model = nn.Sequential()
-- Entry convolution 
model:add( nn.SpatialConvolution(3, 64, 3, 3, 2, 2, 1, 1) )
model:add( nn.SpatialBatchNormalization(64) )
model:add( nn.ReLU() )

-- Xception Unit with "skip-path"
local seq = nn.Sequential()
seq:add( nn.SpatialDepthWiseConvolution(64, 1, 3, 3, 1, 1, 1, 1) )
seq:add( nn.SpatialConvolution(64, 512, 1, 1, 1, 1, 0, 0) )
seq:add( nn.SpatialBatchNormalization(512) )
seq:add( nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1) )

local con = nn.ConcatTable()
con:add( nn.SpatialConvolution(64, 512, 1, 1, 2, 2, 0, 0) )
con:add( seq )

model:add( con )
model:add( nn.CAddTable() )
model:add( nn.ReLU() )

-- Exit fully-connected layers for softmax(3) output
model:add( nn.GlobalAveragePooling() )
model:add( nn.Reshape(512) )

model:add( nn.Linear(512, 3) )
model:add( nn.LogSoftMax() )

print(tostring(model))

local X = torch.randn(10, 3, 16, 8)
local Y = torch.LongTensor(10):random(1,3)

local criterion = nn.ClassNLLCriterion()

local Yhat = model:forward(X)

local loss = criterion:forward(Yhat, Y)
local gradLoss = criterion:backward(Yhat, Y)
model:backward(X, gradLoss)

该模型适用于forward()步骤。 但是当涉及到模型时失败:向后(X,gradLoss)有错误:

    /nn/THNN.lua:110: Need gradOutput of dimension 5 and gradOutput.size[3] == 8 but got gradOutput to be of shape: [10 x 64 x 1 x 4 x 8] at ../THNN/generic/SpatialDepthWiseConvolution.c:53
stack traceback:
[C]: in function 'v'    ../nn/THNN.lua:110: in function 'SpatialDepthWiseConvolution_updateGradInput'   ../nn/SpatialDepthWiseConvolution.lua:80:
in function 'updateGradInput' ../Module.lua:31:
in function <../nn/Module.lua:29>
[C]: in function 'xpcall'   ../nn/Container.lua:63:
in function 'rethrowErrors'     ../nn/Sequential.lua:88:
in function <../nn/Sequential.lua:78>
[C]: in function 'xpcall'   ../Container.lua:63:
in function 'rethrowErrors'     ../nn/ConcatTable.lua:66:
in function <../ConcatTable.lua:30>
[C]: in function 'xpcall'   ../nn/Container.lua:63:
in function 'rethrowErrors'     ../nn/Sequential.lua:84:
in function 'backward'  test.lua:45:
in main chunk   [C]: at 0x00405d50

1 个答案:

答案 0 :(得分:0)

事实证明,问题在于火炬/ nn中的SpatialDepthWiseConvolution的低级实现。我创建了一个问题:https://github.com/torch/nn/issues/1307

目前(2018年3月3日),这个问题尚未解决。 当然,我希望有人会在低级实现中修复错误。 但是现在我知道有两种方法可以解决这个问题:

  • 通过火炬容器模仿
  • 使用我的纯粹实施

以下是使用容器 Concat 并行以及标准 SpatialConvolution 模块模拟此模块的方法:

  local depth_wise_conv = nn.Concat(2)
  for o = 1, nOutputPlane do
    local out = nn.Parallel(2, 2)
    for i = 1, nInputPlane do
      local seq = nn.Sequential()
      local conv = nn.SpatialConvolution(1, 1, kW, kH, dW, dH, pW, pH):noBias()
      seq:add( nn.Reshape(1, inputHeight, inputWidth) )
      seq:add( conv )
      out:add( seq )
    end
    depth_wise_conv:add( out )
  end

请注意,上面代码中的depth_wise_conv模块应该使用4维进行批量输入:batchSize x nInputPlane x inputHeight x inputWidth

但我也花了一些时间并创建了一个纯粹的lua实现 SpatialDepthWiseConvolution 作为模块。你可以在这里找到它: https://gist.github.com/diovisgood/36ce5a6c5e9dd4cb20b13dd2a28c1f71

还有: SpatialConvolution 实施和单元测试。 我在MIT许可下发布了这些模块,所以任何人都可以使用它们。

请注意,这些模块尚未经过严格测试! 所以任何帮助或建议都表示赞赏。

还有一件事。有一篇很好的论文解释了Vincent Dumoulin和Francesco Visin的卷积和转置卷积: 'A guide to convolution arithmetic for deep learning'