深度学习工具箱中DBN的错误结果

时间:2015-04-16 03:14:19

标签: matlab deep-learning

我想要this example。当我使用mnist_uint8数据时,我可以很好地运行此代码。但如果我使用自己的数据运行一个模型(如DBN.m),则此代码为:

[er, bad] = nntest(nn, test_x, test_y); 

将无效,er为零。为什么会这样?我的训练数据的输入大小为320 * 200,输出为320 * 1.

编辑:添加了代码和数据文件

load dataX 
load dataY 
load pdataX 
load pdataY
train_x=dataX/100
test_x=pdataX/100
pdataY(find(pdataY(:,:)<=20))=0;
pdataY(find(pdataY(:,:)>20))=1;
dataY(find(dataY(:,:)<=20))=0;
dataY(find(dataY(:,:)>20))=1;
train_y=dataY
test_y=pdataY
rand('state',0); 
dbn.sizes = [100 40];

%train a 100-40 hidden unit DBN 
opts.numepochs = 1;
opts.batchsize = 40;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn nn = dbnunfoldtonn(dbn, 1);
nn.activation_function = 'sigm';
%train nn opts.numepochs = 1;
opts.batchsize = 40;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

这是数据

https://mega.co.nz/#!9V5wmKYK!q3nAvrzKZCT_Q3Ae-DDNAGDnV57b6Pzq6gtf01w0lD8

1 个答案:

答案 0 :(得分:1)

修改

经过多次讨论(见评论),问题是目标(y)需要使用N-of-N编码格式进行训练和测试。例如,类1的[1 0]和类2的[0 1]。修改的代码产生0.2125的基本错误率。进一步调整和架构更改应该会产生更好的结果。

clear all

load dataX 
load dataY 
load pdataX 
load pdataY
train_x=dataX/100;
test_x=pdataX/100;
pdataY(find(pdataY(:,:)<=20))=0;
pdataY(find(pdataY(:,:)>20))=1;
dataY(find(dataY(:,:)<=20))=0;
dataY(find(dataY(:,:)>20))=1;
train_y=dataY
test_y=pdataY

% Add dimension for one-of-N encoding
train_y(:,2) = 1-train_y(:,1);
test_y(:,2) = 1-test_y(:,1);

rand('state',0)
dbn.sizes = [100 40];

%train a 100-40 hidden unit DBN
opts.numepochs = 2;
opts.batchsize = 40;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 2);
nn.activation_function = 'sigm';

%train nn
opts.numepochs = 100;
opts.batchsize = 40;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

原始答案:

我假设您的训练数据是200个功能和320个训练样例。假设您正确训练,那么您可能需要执行功能缩减。我知道在MNIST数据集上运行的ML算法很受欢迎,它使用主成分分析(参见Matlab函数pca())来截断一些特征。请为我们发布更多代码以实际查看问题。