import torch
a = torch.rand(5,256,120)
min_values, indices = torch.min(a,dim=0)
aa = torch.zeros(256,120)
for i in range(256):
for j in range(120):
aa[i,j] = a[indices[i,j],i,j]
print((aa==min_values).sum()==256*120)
我想知道如何避免使用for-for循环获取aa值? (我想使用索引来选择另一个3-d张量中的元素,因此我不能直接使用min返回的值)
答案 0 :(得分:1)
您可以使用torch.gather
aa = torch.gather(a, 0, indices.unsqueeze(0))