如何检查两个Torch张量或矩阵是否相等?

时间:2015-10-07 15:25:08

标签: lua torch

我需要一个Torch命令来检查两个张量是否具有相同的内容,如果它们具有相同的内容则返回TRUE。

例如:

local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

我应该在此脚本中使用什么而不是EQUIVALENCE_COMMAND

我只是尝试使用==,但它不起作用。

5 个答案:

答案 0 :(得分:21)

https://github.com/torch/torch7/blob/master/doc/maths.md#torcheqa-b

torch.eq(a, b)

实现==运算符比较a中的每个元素(如果b是数字)或b中的每个元素与b中的相应元素。

- UPDATE

来自@deltheil

torch.all(torch.eq(tens_a, tens_b))

甚至更简单

torch.all(tens_a:eq(tens_b))

答案 1 :(得分:5)

如果要忽略浮点数常见的小精度差异,请尝试此操作

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))

答案 2 :(得分:4)

以下解决方案对我有用:

torch.equal(tensorA, tensorB)

来自the documentation

  

True,如果两个张量具有相同的大小和元素,否则为False

答案 3 :(得分:3)

要比较张量,可以按元素进行操作:

torch.eq是明智的选择:

torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])

或者torch.equal就是整个张量:

torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True

但是您可能会迷失方向,因为有时您会忽略一些小的差异。例如,浮点数1.01.0000000001非常接近,您可能会认为它们是相等的。对于这种比较,您有torch.allclose

torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True

从某种意义上讲,比较元素的总数与元素数量相等时,检查每个元素相等是很重要的。如果您有两个张量dt1dt2,则将dt1的元素个数作为dt1.nelement()

使用此公式,您可以得到百分比:

print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())

答案 4 :(得分:0)

您可以转换两个张量为numpy数组:

local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));

a=tens_a.numpy()
b=tens_b.numpy()

然后类似

np.sum(a==b)
4

可以让您对它们的平等程度有一个很好的了解。