我有一个m x n张量(Tensor 1)和另一个k x 2张量(Tensor 2),我希望使用基于Tensor 2的索引提取Tensor 1的所有值。例如;
Tensor1
1 2 3 4 5
6 7 8 9 10
11 12 13 14 15
16 17 18 19 20
[torch.DoubleTensor of size 4x5]
Tensor2
2 1
3 5
1 1
4 3
[torch.DoubleTensor of size 4x2]
功能会产生;
6
15
1
18
答案 0 :(得分:2)
首先想到的解决方案是简单地遍历索引并选择对应的值:
public static void main(String[] args)
{
int i;
//Change this to whatever you want or set it to a argument.
int repeat = 2;
for(i = 1; i <= repeat; i++)
{
func(a, b, i);
}
}
这里function get_elems_simple(tensor, indices)
local res = torch.Tensor(indices:size(1)):typeAs(tensor)
local i = 0
res:apply(
function ()
i = i + 1
return tensor[indices[i]:clone():storage()]
end)
return res
end
只是从多维张量中挑选元素的通用方法。在k维的情况下,这与tensor[indices[i]:clone():storage()]
完全相似。
如果您不必提取大量值,此方法可以正常工作(瓶颈是tensor[{indices[i][1], ... , indices[i][k]}]
方法,该方法无法使用许多优化技术和SIMD指令,因为它执行的功能是黑色的框)。可以更有效地完成工作:方法:apply
完全符合您的需要......具有一维张量。多维目标/指数张量需要展平:
:index
差异很大:
function flatten_indices(sp_indices, shape)
sp_indices = sp_indices - 1
local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
local flat_ind = torch.LongTensor(n_elem):fill(1)
local mult = 1
for d = n_dim, 1, -1 do
flat_ind:add(sp_indices[{{}, d}] * mult)
mult = mult * shape[d]
end
return flat_ind
end
function get_elems_efficient(tensor, sp_indices)
local flat_indices = flatten_indices(sp_indices, tensor:size())
local flat_tensor = tensor:view(-1)
return flat_tensor:index(1, flat_indices)
end