PyTorch - 变量和张量之间的元素乘法?

时间:2017-07-02 21:35:24

标签: pytorch

截至PyTorch 0.4,此问题已不再有效。在0.4 Tensor s和Variable s合并。

如何在PyTorch中使用变量和张量进行逐元素乘法?有两个张量工作正常。变量和标量工作正常。但是当尝试使用变量和张量执行逐元素乘法时,我得到:

XXXXXXXXXXX in mul
    assert not torch.is_tensor(other)
AssertionError

例如,运行以下内容时:

import torch

x_tensor = torch.Tensor([[1, 2], [3, 4]])
y_tensor = torch.Tensor([[5, 6], [7, 8]])

x_variable = torch.autograd.Variable(x_tensor)

print(x_tensor * y_tensor)
print(x_variable * 2)
print(x_variable * y_tensor)

我希望第一个和最后一个打印语句显示类似的结果。前两个乘法按预期工作,错误在第三个中出现。我在PyTorch中尝试了*的别名(例如x_variable.mul(y_tensor)torch.mul(y_tensor, x_variable)等)。

考虑到错误和产生它的代码,似乎不支持张量和变量之间的元素乘法。它是否正确?还是有什么我想念的?谢谢!

1 个答案:

答案 0 :(得分:12)

是的,你是对的。元素乘法(与大多数其他操作一样)仅支持Tensor * TensorVariable * Variable,但不支持 Tensor * Variable

要执行上面的乘法运算,请将Tensor包裹为不需要渐变的Variable。额外的开销是微不足道的。

y_variable = torch.autograd.Variable(y_tensor, requires_grad=False)
x_variable * y_variable # returns Variable

但显然,只有使用Variables,如果你真的需要通过图表自动区分。另外,您可以像在问题中一样直接在Tensors上执行操作。