Pytorch中的编码器-解码器体系结构

时间:2018-08-16 22:10:42

标签: deep-learning conv-neural-network pytorch torch encoder-decoder

我想使用研究论文来实现卷积编码器解码器架构。 但是这篇研究论文用Torch语言发布了他的代码。

在解码器部分,它使用函数ResizeJoinTable()

up1 = {s5, s4}
        - nn.ResizeJoinTable(2)
        - nn.SpatialConvolution(512, 256, 3, 3, 1, 1, 1, 1)
        - nn.SpatialBatchNormalization(256)
        - nn.ReLU()
        - nn.SpatialConvolution(256, 256, 3, 3, 1, 1, 1, 1)
        - nn.SpatialBatchNormalization(256)
        - nn.ReLU()

这是函数ResizeJoinTable()

local ResizeJoinTable, parent = torch.class('nn.ResizeJoinTable', 'nn.Module')

function ResizeJoinTable:__init(dimension)
    parent.__init(self)
    self.size = torch.LongStorage()
    self.dimension = dimension
    self.gradInput = {}

    self.join = nn.JoinTable(dimension, nil)

    self.model = nn.Sequential()
    local params = {owidth = 1; oheight = 1}
    local parallel = nn.ParallelTable()
    parallel:add(nn.SpatialUpSamplingBilinear(params))
    parallel:add(nn.Identity())
    self.model:add(parallel)
    self.model:add(self.join)

    self.model:float()
    self.model:training()
    self.model:cuda()
end

function ResizeJoinTable:_getPositiveDimension(input)
    return self.join:_getPositiveDimension(input)
end

function ResizeJoinTable:updateOutput(input)
    local second = input[2]

    self.model.modules[1].modules[1].owidth = second:size(4)
    self.model.modules[1].modules[1].oheight = second:size(3)

    return self.model:updateOutput(input)
end

function ResizeJoinTable:clearState()
    self.model:clearState();
end

function ResizeJoinTable:updateGradInput(input, gradOutput)
    self.gradInput = self.model:updateGradInput(input, gradOutput)
    return self.gradInput
end

function ResizeJoinTable:type(type, tensorCache)
    self.gradInput = {}
    return parent.type(self, type, tensorCache)
end

我想知道Torch中的此函数ResizeJoinTable()与Pytorch框架中的max_unpool2d()函数是否匹配。

在此先感谢您的帮助!

0 个答案:

没有答案