从torch.Tensor

时间:2017-07-29 12:07:35

标签: lua torch

我在lua中有以下代码写的。

我想获得scores的N个最高分数及其相应分数的索引。

看起来我必须从scores迭代删除当前最大值并再次检索最大值,但无法找到合适的方法。

nqs=dataset['question']:size(1);
scores=torch.Tensor(nqs,noutput);
qids=torch.LongTensor(nqs);
for i=1,nqs,batch_size do
    xlua.progress(i, nqs)
    r=math.min(i+batch_size-1,nqs);
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r);
--    print(scores)
end

tmp,pred=torch.max(scores,2);

1 个答案:

答案 0 :(得分:1)

我希望我没有误会,因为你展示的代码(特别是foor循环)似乎并不真正想要你想做的事情。不管怎么说,我就是这样做的。

 sr=scores:view(-1,scores:size(1)*scores:size(2))
 val,id=sr:sort()
 --val is a row vector with the values stored in increasing order
 --id will be the corresponding index in sr
 --now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with
 col=(index-1)%scores:size(2)+1
 row=math.ceil(index/scores:size(2))

希望这会有所帮助。