我是火炬手和lua的新手(因为任何一直关注我最新帖子的人都可以证明:)并且对gmodule对象的转发函数(类nngraph)有以下问题。
根据源代码(https://github.com/torch/nn/blob/master/Module.lua - 因为类gmodule继承自nn.module)语法为:
function Module:forward(input)
return self:updateOutput(input)
end
但是,我发现了一个表作为输入传递的情况,如:
local lst = clones.rnn[t]:forward{x[{{}, t}], unpack(rnn_state[t-1])}
其中:
clones.rnn[t]
本身就是一个gmodule对象。反过来,rnn_state [t-1]是一个有4个张量的表。所以最后,我们有类似于
的东西result_var = gmodule:forward{[1]=tensor_1,[2]=tensor_2,[3]=tensor_3,...,[5]=tensor_5}
问题是,根据网络架构,你可以传递输入 - 格式化为表 - 不仅是输入层,还有隐藏层?
在这种情况下,您必须检查每层只传递一个输入? (输出层除外)
非常感谢
答案 0 :(得分:0)
我终于找到了答案。模块类(以及继承的类gmodule)具有输入和输出。
然而,输入(以及输出)不需要是矢量,但它可以是矢量的集合 - 这取决于神经网络配置,在这种特殊情况下它是一个相当复杂的递归神经网络。
因此,如果网络有多个输入向量,您可以这样做:
result_var = gmodule:forward{[1]=tensor_1,[2]=tensor_2,[3]=tensor_3,...,[5]=tensor_5}
其中每个张量/向量是输入向量之一。这些矢量中只有一个是X矢量或特征矢量。其他可以作为其他中间节点的输入。
反过来,result_var(输出)可以有一个输出作为张量(预测)或一组张量作为输出(张量集合),具体取决于网络配置。
如果是后者,那么其中一个输出张量就是预测,提醒通常在下一个时间步骤中用作中间节点的输入 - 但这又取决于网络配置。