如何在火炬的nngraph包中向图形模块(gModule)添加新节点?我尝试使用add函数,这将节点添加到gModules对象的模块插槽中。但是,输出仍然来自前一个节点。
简化代码:
require "nn"
require "nngraph"
-- Function that builds a gModule
function buildModule(input_size,hidden_size)
local x = nn.Identity()()
local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
return nn.gModule({x},{out})
end
network = buildModule(5,3)
-- Additional layer to add
l2 = nn.Linear(3,10)
network:add(l2)
-- Expected a tensor of size 10 but got one with size 3
print(network:forward(torch.randn(5)))
答案 0 :(得分:1)
gModule实际上不应该被改变。它支持的事实:add实际上是作为nn.Container的子类的副作用,而不是设计决策。通常,一旦创建了gModule,就不应该修改它的内部结构,因为你必须修改一些内部属性才能使其工作。相反 - 如果你想在“顶部”添加一些东西,只需定义新容器,将前一个容器作为输入。
-- Function that builds a gModule
function buildModule(input_size,hidden_size)
local x = nn.Identity()()
local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh()
return nn.gModule({x},{out})
end
network = buildModule(5,3)
new_network = nn.Sequential()
new_network:add(network)
new_network:add(nn.Linear(3,10))