topk / sort和pick显然是矛盾的结果

时间:2017-07-04 16:38:42

标签: python mxnet

我使用相当标准的softmax输出,使用MXNet模型预测大约100K可能输出中的一个。我想比较分配给真实标签的概率与模型下的最高预测。为了得到前者,我使用了pick操作符;后来我尝试了便宜的版本(topk运营商)和昂贵的版本(sort / argsort + slice)。

在这两种情况下,我都会得到相互矛盾的结果。具体而言,存在许多情况,其中真实标签(用挑选检索)的概率显着高于最高概率输出(用topk / sort检索)。我认为这意味着我做错了什么但不明白什么。并非所有预测都会发生,但它确实存在很大一部分。

有人能给我一些关于发生了什么的提示吗?

代码如下:

for batch in data_iter:
    model.forward(batch, is_train=False)
    predictions = model.get_outputs()[0]
    labels = batch.label[0].as_in_context(predictions.context)

    # scores = mx.nd.topk(predictions, axis=1, k=6, ret_typ='value')
    scores = mx.nd.sort(predictions, axis=1, is_ascend=0)
    scores = mx.nd.slice_axis(scores, axis=1, begin=0, end=6)

    label_score = mx.nd.pick(predictions, labels, axis=1)
    equal = label_score.asnumpy() <= scores.asnumpy()[:, 0]

    if not np.all(equal):
        #I think this should never happen but it does frequently

1 个答案:

答案 0 :(得分:1)

使用MXNet 1.1.0进行测试,以下代码显示问题不会发生:

for _ in range(10):
    predictions = nd.random.uniform(shape=(100, 100000))
    labels = nd.array(np.random.randint(0, 99999, size=(100, 1)))

    scores = mx.nd.sort(predictions, axis=1, is_ascend=0)
    scores = mx.nd.slice_axis(scores, axis=1, begin=0, end=6)

    label_score = mx.nd.pick(predictions, labels, axis=1)
    equal = label_score.asnumpy() <= scores.asnumpy()[:, 0]

    if not np.all(equal):
        print("ERROR")