如何使用CNTK classification_error()?

时间:2017-05-01 18:38:40

标签: cntk

我正在尝试理解cntk.metrics.classification_error()的正确用法,并用它来验证一批针对其基本事实的预测。

以下玩具示例(基于Python API docs):

import numpy as np
from cntk.metrics import classification_error

predictions = np.asarray([[1., 2., 3., 4.],[1., 2., 3., 4.],[1., 2., 3., 4.]], dtype=np.float32)
labels = np.asarray([[0., 0., 0., 1.],[0., 0., 0., 1.],[0., 0., 1., 0.]], dtype=np.float32)
classification_error(predictions, labels).eval()

产生以下结果:

array([[ 0.,  0.,  1.],
       [ 0.,  0.,  1.],
       [ 0.,  0.,  1.]], dtype=float32)

有没有办法可以获得一个向量而不是方形矩阵,因为我想处理一个大批量的方法,它看起来效率很低?

我在调用axis时尝试使用classification_error()关键字,但无论是设置axis=0还是axis=1,我都会得到一个空的结果。

1 个答案:

答案 0 :(得分:2)

这是因为CNTK试图用户友好并最终对类型感到困惑:-)你可以说,因为分类错误甚至不正确。

如果你添加一点点输入信息,它就会获得正确的语义。

p = C.input(4)
y = C.input(4)
classification_error(p, y).eval({p:predictions, y:labels})
array([[ 0.],
       [ 0.],
       [ 1.]], dtype=float32)

我们将努力修复以防止混淆。