Pytorch:以特定的Tensor维设置索引(类似于torch.index_select)

时间:2018-10-31 19:20:50

标签: python pytorch

我正在尝试在特定的张量维度中获取并设置索引,如果可能的话,不进行重塑。我已经能够找到torch.index_select函数,该函数可以在获取值时实现我想要的功能,但是我还没有找到类似的setter函数。是否存在?

对于上下文,我有一个张量和一组索引

class_energy = torch.rand(3, 10, 32, 32)
class_logits = torch.empty_like(class_energy)
idxs = [2, 3, 5, 7]

我想访问特定维度的索引处的项目,因此我可以执行log_softmax。

如果我知道先验暗淡,那么我可以简单地使用__getitem__ / __setitem__语法:例如,如果dim=1,则class_energy[:, idxs]。同样,如果dim=2-> class_energy[:, :, idxs]dim=0-> class_energy[idxs]等...

dim=1的情况下,我本质上是这样想要的:

class_logits[:, idxs] = F.log_softmax(class_energy[:, idxs], dim=1)

不幸的是,我不知道dim的值。当然,我可以通过以下方式提前建立幻想指数:

fancy_index = tuple([slice(None)] * dim + [idxs])
class_logits[fancy_index] = F.log_softmax(class_energy[fancy_index], dim=dim)

但是,我想知道是否有更好的方法可以做到这一点。对于__getitem__,我知道有一个事实。以下使用torch.index_select的代码是等效的

fancy_index = tuple([slice(None)] * dim + [idxs])
index = torch.LongTensor(idxs).to(class_energy.device)
class_logits[fancy_index] = F.log_softmax(torch.index_select(class_energy, dim=dim, index=index, dim=dim))

不仅功能上相同,而且index_select比使用花哨的getitem语法要快得多(我见过2倍的改进)。

我的问题是代码的__setitem__部分似乎没有类似的功能。如果我能一起摆脱花哨的索引,那就太好了。我研究了Tensor.put_torch.index_puttorch.select,但是这些似乎都没有我想要的功能。是我缺少的东西,还是花式索引是当前解决此问题的唯一方法?

0 个答案:

没有答案