我的张量形状为:
shape = [batch_size, d_0, d_1, ..., d_k]
和索引列表:
idx = [i_0, i_1, ..., i_k]
是否有一种方法可以用索引d_0, ..., d_k
有效地索引每个暗数i_0, ..., i_k
上的张量?
结果应为具有以下形状的张量:
shape = [batch_size]
k
仅在运行时可用
此刻,我正在创建一个元组o切片,每个维度一个:
idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]
但这并不能说服我。
我想要什么:
tensor[:, *idx]
编辑:this answer无济于事,因为它不仅索引所有维度,而且不仅索引子集!
编辑:示例
a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])
我只想索引最后三个维度,例如:
a[:, indexes[0], indexes[1], indexes[2]]
但是在一般情况下,我不知道indexes
多长时间。
答案 0 :(得分:0)
不幸的是,您 can't provide1 混合了切片和迭代器到索引(例如 a[:,*idx]
)。但是,您可以通过将其包装在括号中以强制转换为迭代器来实现几乎相同的效果:
a[(slice(None), *idx)]
在 Python 中,x[(exp1, exp2, ..., expN)]
等价于 x[exp1, exp2, ..., expN]
;后者只是前者的语法糖。