pytorch乘以4 * 1矩阵和1个大小的变量发生错误

时间:2018-04-20 09:06:19

标签: python pytorch

import torch
from torch.autograd import Variable
import numpy as np

x = np.transpose(np.array([[1, 2, 3, 4]]))
a = Variable(torch.rand(1), requires_grad=True)

print(a * x) # error!

我希望结果如x = [[2] [4] [6] [8]]如果a = 2

有什么解决方案吗?

1 个答案:

答案 0 :(得分:1)

您正在寻找的是矩阵乘法中的点标量积。

尝试:

x = np.transpose(np.array([[1, 2, 3, 4]]))
a = 2
x.dot(a)

输出矩阵[[2] [4] [6] [8]]