PyTorch中的唯一张量值

时间:2017-07-09 05:34:58

标签: pytorch

我想在PyTorch张量中找到不同的值 是否有一些有效的方法来复制Tensorflow的unique op

3 个答案:

答案 0 :(得分:7)

0.4.0中有torch.unique()方法

mvn clean install -Ponline -DskipTests中,您可以尝试:

torch <= 0.3.1

答案 1 :(得分:5)

执行此操作的最佳方法(最简单方法)是转换为numpy并使用numpy的内置unique函数。像这样。

def unique(tensor1d):
    t, idx = np.unique(tensor1d.numpy(), return_inverse=True)
    return torch.from_numpy(t), torch.from_numpy(idx)  

所以,当你尝试时:

t, idx = unique(torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8]))  
# t --> [1, 2, 4, 7, 8]
# idx --> [0, 0, 1, 2, 2, 2, 3, 4, 4]

答案 2 :(得分:0)

torch.unique()> 我们得到两个张量之间的共同项。等效于@ 2 tensor.eq()的获取索引和串联张量最终得到'torch.unique'的共同帮助。

import torch as pt

a = pt.tensor([1,2,3,2,3,4,3,4,5,6])
b = pt.tensor([7,2,3,2,7,4,9,4,9,8])

equal_data = pt.eq(a, b)
pt.unique(pt.cat([a[equal_data],b[equal_data]]))