我最近开始使用Torch框架和Lua脚本语言来玩神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但更简单的方法:
我的想法是我有3个输入,我必须选择前两个,除以它们,然后将结果转发给线性模块。所以,我制作了这个小脚本:
require "nn";
require "optim";
local N = 3;
local input = torch.Tensor{
{1, 2, 3},
{9, 20, 20},
{9, 300, 1},
};
local output = torch.Tensor(N);
for i=1, N do
output[i] = 1;
end
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};
local maxIteration = 100000;
for i=1, maxIteration do
local function f(params)
gradParams:zero();
local outputs = ratioPerceptron:forward(input);
local loss = criterion:forward(outputs, output);
local dloss_doutputs = criterion:backward(outputs, output);
ratioPerceptron:backward(input, dloss_doutputs);
return loss, gradParams;
end
optim.sgd(f, params, optimState);
end
在训练期间调用向后时出现错误:
CDivTable.lua:21:torch.LongStorage和(null)都没有加法运算符
但是如果我从顺序模块中删除CDivTable,并将nn.Reshape和nn.Linear更改为二维输入(因为我们删除了CDivTable,它将二维输入除以产生一个暗淡的输出),如下所示:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());
训练完成且没有错误...有没有其他方法可以划分两个选定的输入并将结果转发给线性模块?
答案 0 :(得分:1)
模块CDivTable
将表作为输入,并将第一个表的元素除以第二个表的元素。在这里,您将网络作为单个输入提供,而不是两个输入的表。这就是我相信null
错误的原因。 Torch无法理解您的输入(包含两个向量)应该被视为两个向量的表。它只看到一个大小2x3
的张量!因此,您必须告诉Torch从输入中创建一个表。因此,您可以使用模块SplitTable(dim)
,该模块会将输入拆分为维度dim
中的表格。
在窄模块之后插入此行ratioPerceptron:add(nn.SplitTable(1))
:
local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());
此外,当您遇到此类错误时,我建议您通过放置print
语句来查看网络计算的内容:在添加创建错误的模块的行之前插入一行print(ratioPerceptron:forward(input))