火炬nngraph建筑节点与nn.Linear

时间:2015-09-01 06:31:00

标签: lua torch

嗨,我是火炬/ lua的新手,我正在做牛津机器学习课程的practical5

我想要实现的是一个简单的层: m = x1 + x2 cmul线性(x3) 其中cmul是元素乘法而线性只是一个线性层。

我的代码如下:

-- params for the linear layer
params = {
    x3_size1 = 10,
    x3_size2 = 30
}
-- dummy nodes to take input data as nodes in graph
x1 = nn.Identity()()
x2 = nn.Identity()()
x3 = nn.Identity()()

-- modeling output = x1 + x2 cmul linear(x3)
l3 = nn.Linear(params.x3_size1, params.x3_size2)(x1)
m23 = nn.CMulTable()({x2,l3})
add = nn.CAddTable()({x1, m23})

-- specify the inputs and outputs of the graph
m = nn.gModule({x1,x2,x3}, {add})

graph.dot(mlp.fg, "mlp")

但是,我收到了错误消息:

  /Users/yiranzhang/torch/install/bin/luajit: /Users/yiranzhang/torch/install/share/lua/5.1/nn/Linear.lua:36: attempt to index local 'input' (a nil value)
stack traceback:
    /Users/yiranzhang/torch/install/share/lua/5.1/nn/Linear.lua:36: in function 'forward'
    /Users/yiranzhang/torch/install/share/lua/5.1/nn/Module.lua:232: in function </Users/yiranzhang/torch/install/share/lua/5.1/nn/Module.lua:231>
    [C]: at 0x0156d0d0
    practical5.lua:32: in main chunk
    [C]: in function 'dofile'
    ...hang/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:131: in main chunk
    [C]: at 0x01013242e0

如果我只想要     aa = nn.Linear(10,20)()

我得到了与上面相同的错误。

即使我按照火炬github上的example

我得到了同样的错误。

更新并已解决:

我想念导入包。虽然nngraphnn在代码中都被称为nn,但它们实际上是不同的包。

应该

require 'nngraph'

我只做了

require 'nn'

1 个答案:

答案 0 :(得分:0)

此行中的最后一个参数应为x3而不是x1:

l3 = nn.Linear(params.x3_size1, params.x3_size2)(x3)