火炬 - 用另一个矩阵查询矩阵

时间:2016-02-08 09:52:50

标签: torch

我有一个m x n张量(Tensor 1)和另一个k x 2张量(Tensor 2),我希望使用基于Tensor 2的索引提取Tensor 1的所有值。例如;

Tensor1
  1   2   3   4   5
  6   7   8   9  10
 11  12  13  14  15
 16  17  18  19  20
[torch.DoubleTensor of size 4x5]

Tensor2
 2  1
 3  5
 1  1
 4  3
[torch.DoubleTensor of size 4x2]

功能会产生;

6
15
1
18

1 个答案:

答案 0 :(得分:2)

首先想到的解决方案是简单地遍历索引并选择对应的值:

public static void main(String[] args)
{
     int i;
     //Change this to whatever you want or set it to a argument.
     int repeat = 2;
     for(i = 1; i <= repeat; i++)
     {
          func(a, b, i);
     }
}

这里function get_elems_simple(tensor, indices) local res = torch.Tensor(indices:size(1)):typeAs(tensor) local i = 0 res:apply( function () i = i + 1 return tensor[indices[i]:clone():storage()] end) return res end 只是从多维张量中挑选元素的通用方法。在k维的情况下,这与tensor[indices[i]:clone():storage()]完全相似。

如果您不必提取大量值,此方法可以正常工作(瓶颈是tensor[{indices[i][1], ... , indices[i][k]}]方法,该方法无法使用许多优化技术和SIMD指令,因为它执行的功能是黑色的框)。可以更有效地完成工作:方法:apply完全符合您的需要......具有一维张量。多维目标/指数张量需要展平:

:index

差异很大:

function flatten_indices(sp_indices, shape)
    sp_indices = sp_indices - 1
    local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
    local flat_ind = torch.LongTensor(n_elem):fill(1)

    local mult = 1
    for d = n_dim, 1, -1 do
        flat_ind:add(sp_indices[{{}, d}] * mult)
        mult = mult * shape[d]
    end
    return flat_ind
end

function get_elems_efficient(tensor, sp_indices)
    local flat_indices = flatten_indices(sp_indices, tensor:size()) 
    local flat_tensor = tensor:view(-1)
    return flat_tensor:index(1, flat_indices)
end