两个n维火炬张量的不相交

时间:2020-07-08 08:24:07

标签: python-3.x algorithm pytorch intersection

提前感谢大家的帮助!我想在PyTorch中做的是为许多维数的张量计算非相交(称其为torch.nonintersection)(没有for循环,因为我希望在GPU上有效执行)。所以这是示例,它应该如何工作:

a = torch.tensor([[ 0.,  0.], [ 0.,  1.], [ 0.,  2.], [ 1.,  0.], [ 1.,  1.], [ 1.,  2.], [ 1.,  3.], 
                  [ 2.,  0.], [ 2.,  1.], [ 2.,  2.]])
b = torch.tensor([[ 2.,  0.], [ 2.,  1.], [ 2.,  2.], [ 1.,  0.], [ 1.,  1.], [ 1.,  2.], [ 1.,  3.]])

torch.spec_unique(a,b) = torch.tensor([ 0.,  0.], [ 0.,  1.], [ 0.,  2.])

我有一些for循环的模拟物,但是它们现在花费了太多时间。任何想法如何做到这一点?非常感谢!

1 个答案:

答案 0 :(得分:0)

ab的形状在您的情况下是不同的(a(10, 2),而b(7, 2)),所以您必须将较大的张量“削减”到较小的张量。下面的函数使用简单的if处理它(不会减慢您的计算速度):

def non_intersection(a, b):
    if a.shape[0] > b.shape[0]:
        return torch.nonzero(a[: b.shape[0]] != b, as_tuple=False)
    return torch.nonzero(b[: a.shape[0]] != a, as_tuple=False)

这将返回:

tensor([[0, 0],
        [1, 0],
        [2, 0]])

因此列的顺序相反。如果您希望像示例一样获取列,则可以在non_intersection输出中执行以下操作:

torch.index_select(non_intersection(a, b), 1, torch.tensor([1, 0])))