mod.predict提供比预期更多的列

时间:2017-03-07 06:29:45

标签: python mxnet

我在IRIS数据集上使用MXNet,它具有4个功能,并将花分类为 - ' setosa',' versicolor' virginica'。我的训练数据有89行。我的标签数据是89列的行向量。我将花名编码为数字-0,1,2,因为看起来mx.io.NDArrayIter不接受带有字符串值的numpy ndarray。然后我尝试使用

进行预测

re = mod.predict(test_iter)

我得到一个形状为14 * 10的结果。 当我只有3个标签时,为什么我会得到10列,如何将这些结果映射到我的标签。预测结果如下所示:

  

[[0.11760861 0.12082944 0.1207106 0.09154381 0.09155304 0.09155869   0.09154817 0.09155204 0.09154914 0.09154641] [0.1176083 0.12082954 0.12071151 0.09154379 0.09155323 0.09155825   0.0915481 0.09155164 0.09154923 0.09154641] [0.11760829 0.1208293 0.12071083 0.09154385 0.09155313 0.09155875   0.09154838 0.09155186 0.09154932 0.09154625] [0.11760861 0.12082901 0.12071037 0.09154388 0.09155303 0.09155875   0.09154829 0.09155209 0.09154959 0.09154641] [0.11760896 0.12082863 0.12070955 0.09154405 0.09155299 0.09155875   0.09154839 0.09155225 0.09154996 0.09154646] [0.1176089 0.1208287 0.1207095 0.09154407 0.09155297 0.09155882   0.09154844 0.09155232 0.09154989 0.0915464] [0.11760896 0.12082864 0.12070941 0.09154408 0.09155297 0.09155882   0.09154844 0.09155234 0.09154993 0.09154642] [0.1176088 0.12082874 0.12070983 0.09154399 0.09155302 0.09155872   0.09154837 0.09155215 0.09154984 0.09154641] [0.11760852 0.12082904 0.12071032 0.09154394 0.09155304 0.09155876   0.09154835 0.09155209 0.09154959 0.09154631] [0.11760963 0.12082832 0.12070873 0.09154428 0.09155257 0.09155893   0.09154856 0.09155177 0.09155051 0.09154671] [0.11760966 0.12082829 0.12070868 0.09154429 0.09155258 0.09155892   0.09154858 0.0915518 0.09155052 0.09154672] [0.11760949 0.1208282 0.12070852 0.09154446 0.09155259 0.09155893   0.09154854 0.09155205 0.0915506 0.09154666] [0.11760952 0.12082817 0.12070853 0.0915444 0.09155261 0.09155891   0.09154853 0.09155206 0.09155057 0.09154668] [0.1176096 0.1208283 0.12070892 0.09154423 0.09155267 0.09155882   0.09154859 0.09155172 0.09155044 0.09154676]]

1 个答案:

答案 0 :(得分:1)

使用“y = mod.predict(val_iter,num_batch = 1)”代替“y = mod.predict(val_iter)”,则只能获得一个批处理标签。例如,如果batch_size为10,那么您将只获得10个标签。