感知器和多层感知器具有1个隐藏节点,用于解决XOR

时间:2016-05-26 14:08:39

标签: lua neural-network torch

这些天我在玩Torch7。

今天,我实施了Perceptron和Multilayer perceptorn(MLP)来解决XOR问题。

正如预期的那样,MLP在XOR上运行良好,而Perceptron则不然。

但我很好奇,如果隐藏节点的数量为1,结果是什么。

我预计MLP的结果可能与Perceptorn相同,因为它只有1个隐藏节点。<​​/ p>

但有趣的是,MLP比Percentron更好。

更多细节,Perceptron得到0.25错误(如预期的那样),但带有1个隐藏节点的MLP得到大约0.16错误。

我认为一个隐藏节点在问题空间中充当一行。

因此,如果只有一个隐藏节点,它可能与Perceptron相同。

但是这个结果告诉我我错了。

现在,我想知道为什么带有1个隐藏节点的MLP比Perceptron好。

请教我这个结果的原因。

非常感谢。

以下是Perceptron代码:

-- perceptron

require 'nn'

-- data
data = torch.Tensor({ {0, 0}, {0, 1}, {1, 0}, {1, 1} })
-- target
target = torch.Tensor({ 0, 1, 1, 0 })

-- model
perceptron = nn.Linear(2, 1)
-- loss function
criterion = nn.MSECriterion()

-- training
for i = 1, 10000 do
   -- set gradients to zero
   perceptron:zeroGradParameters()
   -- compute output
   output = perceptron:forward(data)
   -- compute loss
   loss = criterion:forward(output, target)
   -- compute gradients w.r.t. output
   dldo = criterion:backward(output, target)
   -- compute gradients w.r.t. parameters
   perceptron:backward(data,dldo)
   -- gradient descent with learningRate = 0.1
   perceptron:updateParameters(0.1)
   print(loss)
end

以下是带有1个隐藏节点代码的MLP:

-- multilayer perceptron

require 'nn'

-- data
data = torch.Tensor({ {0, 0}, {0, 1}, {1, 0}, {1, 1} })
-- target
target = torch.Tensor({ 0, 1, 1, 0 })

-- model
multilayer = nn.Sequential()
inputs = 2; outputs = 1; HUs = 1;
multilayer:add(nn.Linear(inputs, HUs))
multilayer:add(nn.Tanh())
multilayer:add(nn.Linear(HUs, outputs))
-- loss function
criterion = nn.MSECriterion()

-- training
for i = 1, 10000 do
   -- set gradients to zero
   multilayer:zeroGradParameters()
   -- compute output
   output = multilayer:forward(data)
   -- compute loss
   loss = criterion:forward(output, target)
   -- compute gradients w.r.t. output
   dldo = criterion:backward(output, target)
   -- compute gradients w.r.t. parameters
   multilayer:backward(data,dldo)
   -- gradient descent with learningRate = 0.1
   multilayer:updateParameters(0.1)
   print(loss)
end

0 个答案:

没有答案