所以,我在SO上遵循了此answer
我正试图使两个张量相等
torch.equal(x_valid[0], x_valid[:1])
返回False
,而
torch.all(torch.eq(x_valid[0], x_valid[:1]))
返回tensor(1, dtype=torch.uint8)
我知道两个张量都像x_valid
的第一个值一样,所以为什么torch.equal返回False
?
除了x_valid[0]
返回([0, 0, ...,0])
和x_valid[:1]
返回([[0, 0, ...,0]])
的事实
,但两者的类型仍为tensor
。因此,我无法真正理解为什么第一个查询的输出为False
答案 0 :(得分:0)
torch.equal(tensor1, tensor2)
返回True
,否则返回False
。选中here。
示例:
y = torch.tensor([[0, 0, 0]])
print(y[0], y[0].shape)
print(y[:1], y[:1].shape)
print(torch.equal(y[0], y[:1]))
print(torch.equal(y[0], y[:1][0])) # (torch.Size([3]), torch.Size([3]))
输出:
tensor([0, 0, 0]) torch.Size([3])
tensor([[0, 0, 0]]) torch.Size([1, 3])
False
True
而torch.eq(input, other, out=None)
计算按元素相等。这里,需要注意的一个重要问题是第二个参数可以是数字或张量,其第一个参数的形状为broadcastable。