提前感谢大家的帮助!我要在PyTorch中尝试做的是numpy的setdiff1d
之类的事情。例如,给出以下两个张量:
t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')
预期输出应为(已排序或未排序):
torch.tensor([9, 12, 5])
理想情况下,操作是在GPU上完成的,并且在GPU和CPU之间不会来回移动。非常感谢!
答案 0 :(得分:2)
我遇到了同样的问题,但是当使用更大的数组时,提出的解决方案太慢了。以下简单的解决方案可在CPU和GPU上运行,并且比其他建议的解决方案快得多:
combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
答案 1 :(得分:1)
如果您不需要for循环,则可以一次性比较所有值。
您也可以轻松获得非交叉点
t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
# Create a tensor to compare all values at once
compareview = t2.repeat(t1.shape[0],1).T
# Intersection
print(t1[(compareview == t1).T.sum(1)==1])
# Non Intersection
print(t1[(compareview != t1).T.prod(1)==1])
tensor([ 1, 24])
tensor([ 9, 12, 5])