PyTorch - 另一个张量中相应值的索引

时间:2021-01-23 14:24:19

标签: python pytorch

我有一个张量,我只想复制其中的一些值(按列)。相同的值位于另一个张量中,但顺序随机。我想要的是来自 tensor2 值的 tensor1 的列索引。下面是一个例子:

copy_ind = torch.tensor([0, 1, 3], dtype=torch.long)
tensor1 = torch.tensor([[4, 6, 5, 1, 8],[10, 0, 8, 2, 1]])
temp = torch.index_select(tensor1, 1, copy_ind) # values to copy
tensor2 = torch.tensor([[1, 4, 5, 6, 8],[2, 10, 8, 0, 1]], dtype=torch.long)
_, t_ind = torch.sort(temp[0], dim=0)
t2_ind = copy_ind[t_ind] # indices of tensor2

输出应该是:

t2_ind = [1, 3, 0]

这是我想根据 c1_new 获取张量值的另一个示例:

c1 = torch.tensor([[6, 7, 7, 8, 6, 8, 9, 4, 7, 6, 1, 3],[5, 11, 5, 7, 2, 9, 5, 5, 7, 11, 10, 7]], dtype=torch.long)
copy_ind = torch.tensor([1, 2, 3, 5, 7, 8], dtype=torch.long)
c1_new = torch.index_select(c1, 1, copy_ind)

indices = torch.as_tensor([[1, 3, 4, 6, 6, 6, 7, 7, 7, 8, 8, 9], [10, 7, 5, 2, 5, 11, 5, 7, 11, 7, 9, 5]])
values = torch.randn(12)
tensor = torch.sparse.FloatTensor(indices, values, (12, 12))

_, t_ind = torch.sort(c1[0], dim=0)
ind = t_ind[copy_ind] # should be [8, 6, 9, 10, 2, 7]

不幸的是,索引 ind 不正确。有人可以帮我吗?

1 个答案:

答案 0 :(得分:0)

如果您可以使用 for 循环,则可以使用以下方法:根据 tensor2 的列检查临时张量的每一列:

编辑:跨维度 1 使用 torch.prod 以确保两行匹配

[torch.prod((temp.T[i] == tesnor2.T), dim=1).nonzero()[0] for i in range(temp.size(1))]

我的第一个示例的输出是 [tensor(1), tensor(3), tensor(0)]