我需要一个Torch命令来检查两个张量是否具有相同的内容,如果它们具有相同的内容则返回TRUE。
例如:
local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});
if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end
我应该在此脚本中使用什么而不是EQUIVALENCE_COMMAND
?
我只是尝试使用==
,但它不起作用。
答案 0 :(得分:21)
https://github.com/torch/torch7/blob/master/doc/maths.md#torcheqa-b
torch.eq(a, b)
实现==运算符比较a中的每个元素(如果b是数字)或b中的每个元素与b中的相应元素。
- UPDATE
来自@deltheil
torch.all(torch.eq(tens_a, tens_b))
甚至更简单
torch.all(tens_a:eq(tens_b))
答案 1 :(得分:5)
如果要忽略浮点数常见的小精度差异,请尝试此操作
torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))
答案 2 :(得分:4)
答案 3 :(得分:3)
要比较张量,可以按元素进行操作:
torch.eq
是明智的选择:
torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])
或者torch.equal
就是整个张量:
torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True
但是您可能会迷失方向,因为有时您会忽略一些小的差异。例如,浮点数1.0
和1.0000000001
非常接近,您可能会认为它们是相等的。对于这种比较,您有torch.allclose
。
torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True
从某种意义上讲,比较元素的总数与元素数量相等时,检查每个元素相等是很重要的。如果您有两个张量dt1
和dt2
,则将dt1
的元素个数作为dt1.nelement()
使用此公式,您可以得到百分比:
print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())
答案 4 :(得分:0)
您可以转换两个张量为numpy数组:
local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));
a=tens_a.numpy()
b=tens_b.numpy()
然后类似
np.sum(a==b)
4
可以让您对它们的平等程度有一个很好的了解。