我正在使用C++ DLib训练简单的MLP(输入层为8单位,隐藏层为50单位),并使用以下89个样本作为训练数据:
float template_user[89][8]={
{0.083651,0.281587,0.370476,0.704444,0.253865,0.056415,0.002344,0.465187},
{0.142540,0.272857,0.376032,0.740952,0.591501,0.227614,0.000000,0.832224},
{0.095625,0.258750,0.447500,0.779792,0.449932,0.000964,0.035104,0.606591},
{0.115208,0.181250,0.478750,0.797083,0.491855,0.015824,0.011652,0.649632},
{0.107436,0.166026,0.509872,0.764231,0.689760,0.266863,0.017919,0.861909},
{0.083077,0.156838,0.506838,0.754701,0.730839,0.073274,0.012962,0.901468},
{0.055965,0.213158,0.468421,0.772982,0.564555,0.020303,0.065653,0.779779},
{0.139667,0.250833,0.452833,0.721667,0.449828,0.005464,0.036171,0.574080},
{0.069885,0.193563,0.474713,0.738621,0.702248,0.067716,0.020190,0.853911},
{0.142879,0.228636,0.387424,0.696212,0.521491,0.215815,0.000000,0.743419},
{0.072727,0.174924,0.537197,0.854545,0.751153,0.162381,0.015300,0.829441},
{0.112857,0.274444,0.487619,0.720952,0.280613,0.005150,0.036518,0.327060},
{0.075208,0.105625,0.638958,0.738958,0.720347,0.033513,0.034740,0.741035},
{0.084314,0.156863,0.540980,0.773137,0.849637,0.045651,0.142826,0.928612},
{0.101453,0.168034,0.570940,0.816068,0.782684,0.051173,0.016929,0.894334},
{0.139298,0.242281,0.433333,0.707719,0.453023,0.019736,0.016849,0.650420},
{0.141667,0.275417,0.483125,0.774375,0.348158,0.007721,0.016579,0.382547},
{0.059048,0.210238,0.444762,0.725238,0.501227,0.028792,0.015904,0.658447},
{0.134127,0.220000,0.474127,0.704762,0.556249,0.022907,0.018309,0.640852},
{0.060741,0.130278,0.653519,0.817685,0.908317,0.036471,0.062615,0.926337},
{0.069425,0.160920,0.502414,0.814483,0.665313,0.053639,0.088214,0.919472},
{0.096543,0.176420,0.547654,0.746420,0.885374,0.062998,0.260892,0.970567},
{0.131410,0.209487,0.472180,0.774487,0.687548,0.314074,0.019767,0.784370},
{0.089487,0.277436,0.465641,0.746410,0.511571,0.004442,0.049289,0.586354},
{0.113509,0.277018,0.416491,0.747193,0.471137,0.026170,0.032863,0.567805},
{0.106167,0.236000,0.390167,0.725500,0.222446,0.012844,0.009073,0.319552},
{0.154697,0.246061,0.361667,0.702424,0.387337,0.130159,0.000000,0.599886},
{0.139848,0.265758,0.371667,0.706061,0.420752,0.063272,0.010781,0.730907},
{0.123333,0.222917,0.457917,0.736111,0.197867,0.018106,0.011322,0.362068},
{0.118333,0.276167,0.391667,0.678667,0.173529,0.000000,0.082133,0.321224},
{0.117361,0.275278,0.369722,0.690000,0.293287,0.132657,0.000000,0.400215},
{0.059167,0.228958,0.433542,0.771667,0.288781,0.009255,0.039148,0.557826},
{0.117576,0.272576,0.372121,0.732121,0.453184,0.190272,0.000000,0.549434},
{0.110400,0.285733,0.484933,0.716133,0.357000,0.014651,0.029153,0.399618},
{0.113636,0.292273,0.424848,0.675909,0.265669,0.000000,0.086476,0.485617},
{0.129111,0.201333,0.398444,0.798667,0.610667,0.391739,0.039324,0.934257},
{0.130702,0.278947,0.488947,0.781228,0.288551,0.002304,0.048600,0.447999},
{0.062424,0.198182,0.406061,0.754545,0.545461,0.048464,0.031975,0.953561},
{0.116042,0.260833,0.397500,0.752917,0.239495,0.019980,0.009038,0.484188},
{0.099545,0.208939,0.452424,0.705303,0.366764,0.037443,0.041399,0.541036},
{0.147917,0.288542,0.519167,0.791250,0.342535,0.000000,0.054587,0.423326},
{0.151528,0.270694,0.376944,0.690417,0.479157,0.261519,0.026807,0.734895},
{0.072222,0.252667,0.380889,0.774000,0.354927,0.084983,0.002151,0.587992},
{0.107037,0.257778,0.390185,0.732222,0.366828,0.156880,0.046430,0.517064},
{0.132500,0.239583,0.386250,0.730417,0.418626,0.072691,0.001799,0.481614},
{0.138788,0.268939,0.359697,0.701818,0.300690,0.143796,0.000000,0.468543},
{0.159130,0.247101,0.360580,0.667681,0.306328,0.177979,0.000000,0.451001},
{0.114694,0.180204,0.522449,0.777211,0.517179,0.240061,0.009289,0.675917},
{0.081037,0.190370,0.496963,0.781630,0.752979,0.194462,0.003395,0.887803},
{0.144561,0.257719,0.406491,0.666316,0.347379,0.010121,0.030200,0.539781},
{0.101961,0.257059,0.378235,0.781569,0.548395,0.226968,0.000000,0.908654},
{0.148000,0.279167,0.362500,0.738833,0.384442,0.113496,0.000000,0.523396},
{0.131746,0.282381,0.359048,0.731429,0.373971,0.010582,0.054514,0.359725},
{0.120476,0.264762,0.372857,0.745952,0.265676,0.121189,0.000000,0.553113},
{0.134444,0.264815,0.360556,0.787407,0.273869,0.088453,0.000000,0.446707},
{0.082963,0.268148,0.453889,0.759815,0.246077,0.013733,0.037363,0.399243},
{0.121449,0.250725,0.397391,0.680145,0.202441,0.000000,0.048686,0.362474},
{0.095417,0.265833,0.384583,0.756250,0.200351,0.032333,0.002281,0.435691},
{0.153485,0.286667,0.377879,0.733030,0.578530,0.314558,0.000000,0.747655},
{0.095694,0.255139,0.425139,0.687917,0.214292,0.007446,0.045981,0.619005},
{0.104800,0.285467,0.369067,0.707200,0.390821,0.112631,0.000000,0.690165},
{0.141667,0.222407,0.396481,0.743704,0.423090,0.092248,0.006714,0.954501},
{0.117451,0.252941,0.361177,0.718431,0.297293,0.077985,0.000000,0.706839},
{0.120702,0.272632,0.364737,0.745088,0.239333,0.060230,0.000777,0.385973},
{0.130556,0.287639,0.370833,0.695139,0.392796,0.216649,0.000000,0.702700},
{0.117333,0.246111,0.397889,0.702778,0.597046,0.363898,0.025308,0.830660},
{0.130000,0.279630,0.370000,0.750741,0.343966,0.183980,0.000000,0.445425},
{0.161154,0.242821,0.383718,0.710256,0.577957,0.156384,0.000000,0.958507},
{0.121282,0.242564,0.388077,0.673974,0.226179,0.014283,0.018323,0.406824},
{0.126061,0.293333,0.380303,0.752727,0.434234,0.057875,0.010695,0.825712},
{0.108833,0.216500,0.381000,0.726500,0.529321,0.116078,0.000926,0.878521},
{0.117500,0.240500,0.370667,0.725667,0.335912,0.197998,0.000000,0.637666},
{0.105447,0.169431,0.528374,0.788374,0.637837,0.127642,0.003151,0.839206},
{0.097917,0.202500,0.414583,0.810000,0.379878,0.048688,0.009243,0.645313},
{0.154062,0.220000,0.367188,0.652396,0.500206,0.403869,0.000000,0.586002},
{0.130351,0.238597,0.374386,0.709824,0.570770,0.319919,0.000000,0.811718},
{0.111667,0.278889,0.412778,0.650000,0.213929,0.005955,0.025740,0.590823},
{0.106296,0.290556,0.473518,0.742963,0.238166,0.000000,0.144586,0.310825},
{0.055463,0.155093,0.515000,0.819815,0.343923,0.084410,0.003949,0.706575},
{0.049306,0.210833,0.368333,0.696389,0.413882,0.123322,0.000000,0.608080},
{0.146863,0.247059,0.369216,0.809608,0.598618,0.181790,0.024801,0.940989},
{0.097917,0.223750,0.380000,0.674167,0.366680,0.040337,0.015240,0.660683},
{0.110667,0.226333,0.413833,0.715167,0.319241,0.107831,0.000000,0.491870},
{0.060682,0.179848,0.480303,0.762046,0.778796,0.361098,0.070085,0.878837},
{0.117451,0.183725,0.403333,0.749020,0.170094,0.005361,0.184146,0.727668},
{0.082407,0.211852,0.432963,0.740926,0.309057,0.003045,0.057845,0.661626},
{0.081404,0.168070,0.411930,0.785263,0.570020,0.027666,0.079844,0.860993},
{0.076019,0.177500,0.548611,0.789444,0.869244,0.033799,0.001967,0.930204},
{0.098437,0.179792,0.491146,0.798542,0.649532,0.100751,0.021504,0.918400}};
其中第一行标记为1(template_user [0]),其他行标记为0(template_user [1]至template_user [88])。
当我尝试预测标签时,基本上总是任何正输入特征向量总是1(或非常接近1)。如果传递带有负值的向量,则看到训练模型预测标签0的唯一方法。知道为什么会这样吗?也许MLP实现不喜欢将标准化的十进制值作为输入吗?
下面是我的完整代码:
#include <iostream>
#include <dlib/mlp.h>
using namespace std;
using namespace dlib;
float template_user[89][8]={
{0.083651,0.281587,0.370476,0.704444,0.253865,0.056415,0.002344,0.465187},
{0.142540,0.272857,0.376032,0.740952,0.591501,0.227614,0.000000,0.832224},
{0.095625,0.258750,0.447500,0.779792,0.449932,0.000964,0.035104,0.606591},
{0.115208,0.181250,0.478750,0.797083,0.491855,0.015824,0.011652,0.649632},
{0.107436,0.166026,0.509872,0.764231,0.689760,0.266863,0.017919,0.861909},
{0.083077,0.156838,0.506838,0.754701,0.730839,0.073274,0.012962,0.901468},
{0.055965,0.213158,0.468421,0.772982,0.564555,0.020303,0.065653,0.779779},
{0.139667,0.250833,0.452833,0.721667,0.449828,0.005464,0.036171,0.574080},
{0.069885,0.193563,0.474713,0.738621,0.702248,0.067716,0.020190,0.853911},
{0.142879,0.228636,0.387424,0.696212,0.521491,0.215815,0.000000,0.743419},
{0.072727,0.174924,0.537197,0.854545,0.751153,0.162381,0.015300,0.829441},
{0.112857,0.274444,0.487619,0.720952,0.280613,0.005150,0.036518,0.327060},
{0.075208,0.105625,0.638958,0.738958,0.720347,0.033513,0.034740,0.741035},
{0.084314,0.156863,0.540980,0.773137,0.849637,0.045651,0.142826,0.928612},
{0.101453,0.168034,0.570940,0.816068,0.782684,0.051173,0.016929,0.894334},
{0.139298,0.242281,0.433333,0.707719,0.453023,0.019736,0.016849,0.650420},
{0.141667,0.275417,0.483125,0.774375,0.348158,0.007721,0.016579,0.382547},
{0.059048,0.210238,0.444762,0.725238,0.501227,0.028792,0.015904,0.658447},
{0.134127,0.220000,0.474127,0.704762,0.556249,0.022907,0.018309,0.640852},
{0.060741,0.130278,0.653519,0.817685,0.908317,0.036471,0.062615,0.926337},
{0.069425,0.160920,0.502414,0.814483,0.665313,0.053639,0.088214,0.919472},
{0.096543,0.176420,0.547654,0.746420,0.885374,0.062998,0.260892,0.970567},
{0.131410,0.209487,0.472180,0.774487,0.687548,0.314074,0.019767,0.784370},
{0.089487,0.277436,0.465641,0.746410,0.511571,0.004442,0.049289,0.586354},
{0.113509,0.277018,0.416491,0.747193,0.471137,0.026170,0.032863,0.567805},
{0.106167,0.236000,0.390167,0.725500,0.222446,0.012844,0.009073,0.319552},
{0.154697,0.246061,0.361667,0.702424,0.387337,0.130159,0.000000,0.599886},
{0.139848,0.265758,0.371667,0.706061,0.420752,0.063272,0.010781,0.730907},
{0.123333,0.222917,0.457917,0.736111,0.197867,0.018106,0.011322,0.362068},
{0.118333,0.276167,0.391667,0.678667,0.173529,0.000000,0.082133,0.321224},
{0.117361,0.275278,0.369722,0.690000,0.293287,0.132657,0.000000,0.400215},
{0.059167,0.228958,0.433542,0.771667,0.288781,0.009255,0.039148,0.557826},
{0.117576,0.272576,0.372121,0.732121,0.453184,0.190272,0.000000,0.549434},
{0.110400,0.285733,0.484933,0.716133,0.357000,0.014651,0.029153,0.399618},
{0.113636,0.292273,0.424848,0.675909,0.265669,0.000000,0.086476,0.485617},
{0.129111,0.201333,0.398444,0.798667,0.610667,0.391739,0.039324,0.934257},
{0.130702,0.278947,0.488947,0.781228,0.288551,0.002304,0.048600,0.447999},
{0.062424,0.198182,0.406061,0.754545,0.545461,0.048464,0.031975,0.953561},
{0.116042,0.260833,0.397500,0.752917,0.239495,0.019980,0.009038,0.484188},
{0.099545,0.208939,0.452424,0.705303,0.366764,0.037443,0.041399,0.541036},
{0.147917,0.288542,0.519167,0.791250,0.342535,0.000000,0.054587,0.423326},
{0.151528,0.270694,0.376944,0.690417,0.479157,0.261519,0.026807,0.734895},
{0.072222,0.252667,0.380889,0.774000,0.354927,0.084983,0.002151,0.587992},
{0.107037,0.257778,0.390185,0.732222,0.366828,0.156880,0.046430,0.517064},
{0.132500,0.239583,0.386250,0.730417,0.418626,0.072691,0.001799,0.481614},
{0.138788,0.268939,0.359697,0.701818,0.300690,0.143796,0.000000,0.468543},
{0.159130,0.247101,0.360580,0.667681,0.306328,0.177979,0.000000,0.451001},
{0.114694,0.180204,0.522449,0.777211,0.517179,0.240061,0.009289,0.675917},
{0.081037,0.190370,0.496963,0.781630,0.752979,0.194462,0.003395,0.887803},
{0.144561,0.257719,0.406491,0.666316,0.347379,0.010121,0.030200,0.539781},
{0.101961,0.257059,0.378235,0.781569,0.548395,0.226968,0.000000,0.908654},
{0.148000,0.279167,0.362500,0.738833,0.384442,0.113496,0.000000,0.523396},
{0.131746,0.282381,0.359048,0.731429,0.373971,0.010582,0.054514,0.359725},
{0.120476,0.264762,0.372857,0.745952,0.265676,0.121189,0.000000,0.553113},
{0.134444,0.264815,0.360556,0.787407,0.273869,0.088453,0.000000,0.446707},
{0.082963,0.268148,0.453889,0.759815,0.246077,0.013733,0.037363,0.399243},
{0.121449,0.250725,0.397391,0.680145,0.202441,0.000000,0.048686,0.362474},
{0.095417,0.265833,0.384583,0.756250,0.200351,0.032333,0.002281,0.435691},
{0.153485,0.286667,0.377879,0.733030,0.578530,0.314558,0.000000,0.747655},
{0.095694,0.255139,0.425139,0.687917,0.214292,0.007446,0.045981,0.619005},
{0.104800,0.285467,0.369067,0.707200,0.390821,0.112631,0.000000,0.690165},
{0.141667,0.222407,0.396481,0.743704,0.423090,0.092248,0.006714,0.954501},
{0.117451,0.252941,0.361177,0.718431,0.297293,0.077985,0.000000,0.706839},
{0.120702,0.272632,0.364737,0.745088,0.239333,0.060230,0.000777,0.385973},
{0.130556,0.287639,0.370833,0.695139,0.392796,0.216649,0.000000,0.702700},
{0.117333,0.246111,0.397889,0.702778,0.597046,0.363898,0.025308,0.830660},
{0.130000,0.279630,0.370000,0.750741,0.343966,0.183980,0.000000,0.445425},
{0.161154,0.242821,0.383718,0.710256,0.577957,0.156384,0.000000,0.958507},
{0.121282,0.242564,0.388077,0.673974,0.226179,0.014283,0.018323,0.406824},
{0.126061,0.293333,0.380303,0.752727,0.434234,0.057875,0.010695,0.825712},
{0.108833,0.216500,0.381000,0.726500,0.529321,0.116078,0.000926,0.878521},
{0.117500,0.240500,0.370667,0.725667,0.335912,0.197998,0.000000,0.637666},
{0.105447,0.169431,0.528374,0.788374,0.637837,0.127642,0.003151,0.839206},
{0.097917,0.202500,0.414583,0.810000,0.379878,0.048688,0.009243,0.645313},
{0.154062,0.220000,0.367188,0.652396,0.500206,0.403869,0.000000,0.586002},
{0.130351,0.238597,0.374386,0.709824,0.570770,0.319919,0.000000,0.811718},
{0.111667,0.278889,0.412778,0.650000,0.213929,0.005955,0.025740,0.590823},
{0.106296,0.290556,0.473518,0.742963,0.238166,0.000000,0.144586,0.310825},
{0.055463,0.155093,0.515000,0.819815,0.343923,0.084410,0.003949,0.706575},
{0.049306,0.210833,0.368333,0.696389,0.413882,0.123322,0.000000,0.608080},
{0.146863,0.247059,0.369216,0.809608,0.598618,0.181790,0.024801,0.940989},
{0.097917,0.223750,0.380000,0.674167,0.366680,0.040337,0.015240,0.660683},
{0.110667,0.226333,0.413833,0.715167,0.319241,0.107831,0.000000,0.491870},
{0.060682,0.179848,0.480303,0.762046,0.778796,0.361098,0.070085,0.878837},
{0.117451,0.183725,0.403333,0.749020,0.170094,0.005361,0.184146,0.727668},
{0.082407,0.211852,0.432963,0.740926,0.309057,0.003045,0.057845,0.661626},
{0.081404,0.168070,0.411930,0.785263,0.570020,0.027666,0.079844,0.860993},
{0.076019,0.177500,0.548611,0.789444,0.869244,0.033799,0.001967,0.930204},
{0.098437,0.179792,0.491146,0.798542,0.649532,0.100751,0.021504,0.918400}
};
int main() {
typedef matrix<double, 8, 1> sample_type;
sample_type sample;
mlp::kernel_1a_c net(8,50);
cout << "- - - - Trainining started - - - " << endl;
for(int i=0; i<89; i++){
for(int j=0; j<8;j++){
sample(j) = template_user[i][j];
cout << i << ") - sample(" << j << ")=" << sample(j) << endl;
}
if (i == 0) {
net.train(sample,1);
cout << "Trained as label 1" << endl;
} else {
net.train(sample,0);
cout << "Trained as label 0" << endl;
}
}
cout << "- - - - Trainining ended - - - " << endl;
// - - - - - - - - -
// Testing the MLP
sample_type sample2;
// template_user[0]
sample2(0) = 0.083651;
sample2(1) = 0.281587;
sample2(2) = 0.370476;
sample2(3) = 0.704444;
sample2(4) = 0.253865;
sample2(5) = 0.056415;
sample2(6) = 0.002344;
sample2(7) = 0.465187;
cout << "This user (#0) should be close to label 1 and it is classified as a label " << net(sample2) << endl;
// template_user[1]
sample2(0) = 0.142540;
sample2(1) = 0.272857;
sample2(2) = 0.376032;
sample2(3) = 0.740952;
sample2(4) = 0.591501;
sample2(5) = 0.227614;
sample2(6) = 0.000000;
sample2(7) = 0.832224;
cout << "This user (#1) should be close to label 0 and it is classified as a label " << net(sample2) << endl;
sample2(0) = -1110;
sample2(1) = -1110;
sample2(2) = -1110;
sample2(3) = -1110;
sample2(4) = -1110;
sample2(5) = -1110;
sample2(6) = -1110;
sample2(7) = -1110;
cout << "This user should be close to label 0 and it is classified as a label " << net(sample2) << endl;
}
这是输出的最后一部分:
- - - - Trainining ended - - -
This user (#0) should be close to label 1 and it is classified as a label 1
This user (#1) should be close to label 0 and it is classified as a label 1
This user should be close to label 0 and it is classified as a label 0.467338