我想知道是否可以获取输入的torch.argmax(不包括某些索引)。 例如,
setSeriesShapesVisible
我想获得输入中的最大值(不包括目标上的索引),这样结果将是
target = torch.tensor([1,2])
input = torch.tensor([[0.1,0.5,0.2,0.2], [0.1,0.5,0.1,0.3]])
答案 0 :(得分:3)
您可以尝试
target
索引设置负infy torch.max
或torch.argmax
tmp_input = input.clone()
tmp_input[range(len(input)), target] = float("-Inf")
torch.max(tmp_input, dim=1).values
tensor([0.2000, 0.5000])
torch.max(tmp_input, dim=1).indices
tensor([3, 1])
torch.argmax(tmp_input, dim=1)
tensor([3, 1])
答案 1 :(得分:1)
input[target[0]-1,target[1]-1] = -1 # or use -inf
#-1 is added for python indexing style
output = torch.max(input,dim = 1)