如何在火炬中向图形模块添加其他图层

时间:2016-10-31 13:24:07

标签: lua neural-network torch

如何在火炬的nngraph包中向图形模块(gModule)添加新节点?我尝试使用add函数,这将节点添加到gModules对象的模块插槽中。但是,输出仍然来自前一个节点。

简化代码:

require "nn"
require "nngraph"

-- Function that builds a gModule
function buildModule(input_size,hidden_size)
    local x = nn.Identity()()
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
    return nn.gModule({x},{out})
end

network = buildModule(5,3)
-- Additional layer to add
l2 = nn.Linear(3,10)
network:add(l2)

-- Expected a tensor of size 10 but got one with size 3
print(network:forward(torch.randn(5)))

1 个答案:

答案 0 :(得分:1)

gModule实际上不应该被改变。它支持的事实:add实际上是作为nn.Container的子类的副作用,而不是设计决策。通常,一旦创建了gModule,就不应该修改它的内部结构,因为你必须修改一些内部属性才能使其工作。相反 - 如果你想在“顶部”添加一些东西,只需定义新容器,将前一个容器作为输入。

-- Function that builds a gModule
function buildModule(input_size,hidden_size)
    local x = nn.Identity()()
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
    return nn.gModule({x},{out})
end

network = buildModule(5,3)

new_network = nn.Sequential()
new_network:add(network)
new_network:add(nn.Linear(3,10))