调用向后时nn.CDivTable是否抛出错误是否有正当理由?

时间:2017-04-28 12:35:41

标签: machine-learning lua neural-network torch

我最近开始使用Torch框架和Lua脚本语言来玩神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但更简单的方法:

我的想法是我有3个输入,我必须选择前两个,除以它们,然后将结果转发给线性模块。所以,我制作了这个小脚本:

require "nn";
require "optim";

local N = 3;

local input = torch.Tensor{
    {1, 2, 3},
    {9, 20, 20},
    {9, 300, 1},
};

local output = torch.Tensor(N);
for i=1, N do
    output[i] = 1;
end

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};

local maxIteration = 100000;
for i=1, maxIteration do
    local function f(params)
        gradParams:zero();

        local outputs = ratioPerceptron:forward(input);
        local loss = criterion:forward(outputs, output);
        local dloss_doutputs = criterion:backward(outputs, output);
        ratioPerceptron:backward(input, dloss_doutputs);

        return loss, gradParams;
    end

    optim.sgd(f, params, optimState);
end

在训练期间调用向后时出现错误:

  

CDivTable.lua:21:torch.LongStorage和(null)都没有加法运算符

但是如果我从顺序模块中删除CDivTable,并将nn.Reshape和nn.Linear更改为二维输入(因为我们删除了CDivTable,它将二维输入除以产生一个暗淡的输出),如下所示:

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());

训练完成且没有错误...有没有其他方法可以划分两个选定的输入并将结果转发给线性模块?

1 个答案:

答案 0 :(得分:1)

模块CDivTable将表作为输入,并将第一个表的元素除以第二个表的元素。在这里,您将网络作为单个输入提供,而不是两个输入的表。这就是我相信null错误的原因。 Torch无法理解您的输入(包含两个向量)应该被视为两个向量的表。它只看到一个大小2x3的张量!因此,您必须告诉Torch从输入中创建一个表。因此,您可以使用模块SplitTable(dim),该模块会将输入拆分为维度dim中的表格。

在窄模块之后插入此行ratioPerceptron:add(nn.SplitTable(1))

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

此外,当您遇到此类错误时,我建议您通过放置print语句来查看网络计算的内容:在添加创建错误的模块的行之前插入一行print(ratioPerceptron:forward(input))