截至PyTorch 0.4,此问题已不再有效。在0.4 Tensor
s和Variable
s合并。
如何在PyTorch中使用变量和张量进行逐元素乘法?有两个张量工作正常。变量和标量工作正常。但是当尝试使用变量和张量执行逐元素乘法时,我得到:
XXXXXXXXXXX in mul
assert not torch.is_tensor(other)
AssertionError
例如,运行以下内容时:
import torch
x_tensor = torch.Tensor([[1, 2], [3, 4]])
y_tensor = torch.Tensor([[5, 6], [7, 8]])
x_variable = torch.autograd.Variable(x_tensor)
print(x_tensor * y_tensor)
print(x_variable * 2)
print(x_variable * y_tensor)
我希望第一个和最后一个打印语句显示类似的结果。前两个乘法按预期工作,错误在第三个中出现。我在PyTorch中尝试了*
的别名(例如x_variable.mul(y_tensor)
,torch.mul(y_tensor, x_variable)
等)。
考虑到错误和产生它的代码,似乎不支持张量和变量之间的元素乘法。它是否正确?还是有什么我想念的?谢谢!
答案 0 :(得分:12)
是的,你是对的。元素乘法(与大多数其他操作一样)仅支持Tensor * Tensor
或Variable * Variable
,但不支持 Tensor * Variable
。
要执行上面的乘法运算,请将Tensor
包裹为不需要渐变的Variable
。额外的开销是微不足道的。
y_variable = torch.autograd.Variable(y_tensor, requires_grad=False)
x_variable * y_variable # returns Variable
但显然,只有使用Variables
,如果你真的需要通过图表自动区分。另外,您可以像在问题中一样直接在Tensors
上执行操作。