如何在PyTorch中对dot数据进行“点”加权?

时间:2017-11-13 22:32:19

标签: pytorch

我有批量数据,想要dot()数据。 W是可训练的参数。 如何在批量数据和权重之间进行点击?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

更新

这个怎么样?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

1 个答案:

答案 0 :(得分:1)

展开W张量以匹配data张量的形状。以下应该有效。

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)

修改:您的更新代码是正确的。由于您要将W转换为Bxhid_dimx1且数据的形状为Bxdxhid_dim,因此批量矩阵乘法将产生Bxdx1,这基本上是{{之间的点积1}}参数和Wdata)中的所有行向量。