具有可变长度数组的索引多维火炬张量

时间:2020-08-07 22:26:41

标签: python indexing tuples pytorch tensor

我的张量形状为:

shape = [batch_size, d_0, d_1, ..., d_k]

和索引列表:

idx = [i_0, i_1, ..., i_k]

是否有一种方法可以用索引d_0, ..., d_k有效地索引每个暗数i_0, ..., i_k上的张量? 结果应为具有以下形状的张量:

shape = [batch_size]

k仅在运行时可用

此刻,我正在创建一个元组o切片,每个维度一个:

idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]

但这并不能说服我。

我想要什么:

tensor[:, *idx]

编辑:this answer无济于事,因为它不仅索引所有维度,而且不仅索引子集!

编辑:示例

a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])

我只想索引最后三个维度,例如:

a[:, indexes[0], indexes[1], indexes[2]]

但是在一般情况下,我不知道indexes多长时间。

1 个答案:

答案 0 :(得分:0)

不幸的是,您 can't provide1 混合了切片和迭代器到索引(例如 a[:,*idx])。但是,您可以通过将其包装在括号中以强制转换为迭代器来实现几乎相同的效果:

a[(slice(None), *idx)]

  1. <块引用>

    在 Python 中,x[(exp1, exp2, ..., expN)] 等价于 x[exp1, exp2, ..., expN];后者只是前者的语法糖。

    https://numpy.org/doc/stable/reference/arrays.indexing.html