我正在尝试在特定的张量维度中获取并设置索引,如果可能的话,不进行重塑。我已经能够找到torch.index_select
函数,该函数可以在获取值时实现我想要的功能,但是我还没有找到类似的setter函数。是否存在?
对于上下文,我有一个张量和一组索引
class_energy = torch.rand(3, 10, 32, 32)
class_logits = torch.empty_like(class_energy)
idxs = [2, 3, 5, 7]
我想访问特定维度的索引处的项目,因此我可以执行log_softmax。
如果我知道先验暗淡,那么我可以简单地使用__getitem__
/ __setitem__
语法:例如,如果dim=1
,则class_energy[:, idxs]
。同样,如果dim=2
-> class_energy[:, :, idxs]
,dim=0
-> class_energy[idxs]
等...
在dim=1
的情况下,我本质上是这样想要的:
class_logits[:, idxs] = F.log_softmax(class_energy[:, idxs], dim=1)
不幸的是,我不知道dim
的值。当然,我可以通过以下方式提前建立幻想指数:
fancy_index = tuple([slice(None)] * dim + [idxs])
class_logits[fancy_index] = F.log_softmax(class_energy[fancy_index], dim=dim)
但是,我想知道是否有更好的方法可以做到这一点。对于__getitem__
,我知道有一个事实。以下使用torch.index_select
的代码是等效的
fancy_index = tuple([slice(None)] * dim + [idxs])
index = torch.LongTensor(idxs).to(class_energy.device)
class_logits[fancy_index] = F.log_softmax(torch.index_select(class_energy, dim=dim, index=index, dim=dim))
不仅功能上相同,而且index_select比使用花哨的getitem语法要快得多(我见过2倍的改进)。
我的问题是代码的__setitem__
部分似乎没有类似的功能。如果我能一起摆脱花哨的索引,那就太好了。我研究了Tensor.put_
,torch.index_put
和torch.select
,但是这些似乎都没有我想要的功能。是我缺少的东西,还是花式索引是当前解决此问题的唯一方法?