说我有两个张量,如下:
a = torch.tensor([[1, 2, 3], [1, 2, 3]])
b = torch.tensor([0, 2, 3, 4])
其中 b 是 a 的查找值,例如:
b[a]
将返回以下值:
tensor([[2, 3, 4], [2, 3, 4]])
我的问题是,如果我只有一个查找表,该怎么办?
c = torch.tensor([0, 2, 3])
对于每个索引不足的地方,我希望将其分配给索引0,例如 c [a] 将返回
tensor([[2, 3, 0], [2, 3, 0]])
如果我运行 c [a] ,我当然会得到以下结果:
RuntimeError: index 3 is out of bounds for dim with size 3
感谢您的帮助。
答案 0 :(得分:1)
代码
# replace values greater than a certain number
def custom_replace(tensor, value, on_value):
# we create a copy of the original tensor,
# because of the way we are replacing them.
res = tensor.clone()
res[tensor>=value] = on_value
return res
a = torch.tensor([[1, 2, 3], [1, 2, 3]])
c = torch.tensor([0, 2, 3])
a_ = custom_replace(a, c.size(0), 0)
print(c[a_])
输出
tensor([[2, 3, 0],
[2, 3, 0]])