Pytorch:如何使用割炬张量为查找表分配默认值

时间:2019-05-16 09:13:02

标签: pytorch

说我有两个张量,如下:

a = torch.tensor([[1, 2, 3], [1, 2, 3]])
b = torch.tensor([0, 2, 3, 4])

其中 b a 的查找值,例如:

b[a]

将返回以下值:

tensor([[2, 3, 4], [2, 3, 4]])

我的问题是,如果我只有一个查找表,该怎么办?

c = torch.tensor([0, 2, 3])

对于每个索引不足的地方,我希望将其分配给索引0,例如 c [a] 将返回

tensor([[2, 3, 0], [2, 3, 0]])

如果我运行 c [a] ,我当然会得到以下结果:

RuntimeError: index 3 is out of bounds for dim with size 3

感谢您的帮助。

1 个答案:

答案 0 :(得分:1)

  

代码

# replace values greater than a certain number
def custom_replace(tensor, value, on_value):
    # we create a copy of the original tensor, 
    # because of the way we are replacing them.
    res = tensor.clone()
    res[tensor>=value] = on_value
    return res

a = torch.tensor([[1, 2, 3], [1, 2, 3]])
c = torch.tensor([0, 2, 3])

a_ = custom_replace(a, c.size(0), 0)
print(c[a_])
  

输出

tensor([[2, 3, 0],
        [2, 3, 0]])