有没有一种方法可以将每个维度作为元素进行Pytorch元素明智的相等?

时间:2019-07-25 19:23:51

标签: python pytorch

我有两个张量,我想检查将一维数组作为元素的相等性

我有2个张量

lo = torch.Tensor(([1., 1., 0.],
                   [0., 1., 1.],
                   [0., 0., 0.],
                   [1., 1., 1.]))
lo = torch.Tensor(([1., 1., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.]))

我尝试使用 torch.eq(lee, lo) 返回张量之类的

tensor([[1, 1, 1],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0]], dtype=torch.uint8)

有没有办法使输出变为

tensor([1, 0, 1, 0])

第一个匹配的唯一完整元素是

编辑: 我想出了这个解决方案

lee = lee.tolist()
lo = lo.tolist()
out = []
for i, j in enumerate(lee):
  if j == lo[i]:
    out.append(1)
  else:
    out.append(0)

,输出将为[1、0、1、0] 但是有没有更简单的方法?

2 个答案:

答案 0 :(得分:1)

或者采用torch.eq(lee,lo),并且row必须求和为len,这意味着所有1必须存在

import torch
lo = torch.Tensor(([1., 1., 0.],
                   [0., 1., 1.],
                   [0., 0., 0.],
                   [1., 1., 1.]))
l1 = torch.Tensor(([1., 1., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.]))


teq = torch.eq(l1, lo) 

print(teq)

tsm =  teq.sum(-1)

print(tsm == 3)

tsm是张量([3,1,3,0]) 打印输出返回[1、0、1、0]

答案 1 :(得分:1)

您可以简单地使用torch.all(tensor, dim)

代码:

l1 = torch.Tensor(([1., 1., 0.],
                   [0., 1., 1.],
                   [0., 0., 0.],
                   [1., 1., 1.]))
l2 = torch.Tensor(([1., 1., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.]))
print(torch.eq(l1, l2))
print(torch.all(torch.eq(l1, l2),  dim=0)) # equivalent to dim = -2
print(torch.all(torch.eq(l1, l2),  dim=1)) # equivalent to dim = -1

输出:

tensor([[1, 1, 1],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0]], dtype=torch.uint8)
tensor([0, 0, 0], dtype=torch.uint8)
tensor([1, 0, 1, 0], dtype=torch.uint8) # your desired output