如何获得不包含某些索引的argmaxed火炬张量?

时间:2020-09-07 06:52:52

标签: pytorch

我想知道是否可以获取输入的torch.argmax(不包括某些索引)。 例如,

setSeriesShapesVisible

我想获得输入中的最大值(不包括目标上的索引),这样结果将是

target = torch.tensor([1,2])
input = torch.tensor([[0.1,0.5,0.2,0.2], [0.1,0.5,0.1,0.3]])

2 个答案:

答案 0 :(得分:3)

您可以尝试

  • 为温度张量中的target索引设置负infy
  • 然后使用torch.maxtorch.argmax
tmp_input = input.clone()
tmp_input[range(len(input)), target] = float("-Inf")

torch.max(tmp_input, dim=1).values
tensor([0.2000, 0.5000])

torch.max(tmp_input, dim=1).indices
tensor([3, 1])

torch.argmax(tmp_input, dim=1)
tensor([3, 1])

答案 1 :(得分:1)

input[target[0]-1,target[1]-1] = -1 # or use -inf 
#-1 is added for python indexing style
output  = torch.max(input,dim = 1)